diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..82a7236bf03797a01d3ad6e6a60cdd8cdf6d0c1a --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +build +*.so +runs \ No newline at end of file diff --git a/README.md b/README.md index 587f35499d53dab50cbb593a5974f4643e9d558d..14234d8508d048efcfff1a305dae7ce7749aced2 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ --- -title: 3DTopia XL +title: 3DTopia-XL emoji: 🌖 colorFrom: green colorTo: pink sdk: gradio sdk_version: 4.41.0 +python_version: 3.9 app_file: app.py pinned: false --- diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d509be562a20f610f4250deab89b5f53c350b761 --- /dev/null +++ b/app.py @@ -0,0 +1,209 @@ +import os +import imageio +import numpy as np + +os.system("bash install.sh") + +from omegaconf import OmegaConf +import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +import rembg +import gradio as gr +from dva.io import load_from_config +from dva.ray_marcher import RayMarcher +from dva.visualize import visualize_primvolume, visualize_video_primvolume +from inference import remove_background, resize_foreground, extract_texmesh +from models.diffusion import create_diffusion +from huggingface_hub import hf_hub_download +ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt") +vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt") + +GRADIO_PRIM_VIDEO_PATH = 'prim.mp4' +GRADIO_RGB_VIDEO_PATH = 'rgb.mp4' +GRADIO_MAT_VIDEO_PATH = 'mat.mp4' +GRADIO_GLB_PATH = 'pbr_mesh.glb' +CONFIG_PATH = "./configs/inference_dit.yml" + +config = OmegaConf.load(CONFIG_PATH) +config.checkpoint_path = ckpt_path +config.model.vae_checkpoint_path = vae_ckpt_path +# model +model = load_from_config(config.model.generator) +state_dict = torch.load(config.checkpoint_path, map_location='cpu') +model.load_state_dict(state_dict['ema']) +vae = load_from_config(config.model.vae) +vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu') +vae.load_state_dict(vae_state_dict['model_state_dict']) +conditioner = load_from_config(config.model.conditioner) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +vae = vae.to(device) +conditioner = conditioner.to(device) +model = model.to(device) +model.eval() + +amp = True +precision_dtype = torch.float16 + +rm = RayMarcher( + config.image_height, + config.image_width, + **config.rm, +).to(device) + +perchannel_norm = False +if "latent_mean" in config.model: + latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device) + latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device) + assert latent_mean.shape[-1] == config.model.generator.in_channels + perchannel_norm = True + +config.diffusion.pop("timestep_respacing") +config.model.pop("vae") +config.model.pop("vae_checkpoint_path") +config.model.pop("conditioner") +config.model.pop("generator") +config.model.pop("latent_nf") +config.model.pop("latent_mean") +config.model.pop("latent_std") +model_primx = load_from_config(config.model) +# load rembg +rembg_session = rembg.new_session() + +# process function +def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0): + # seed + torch.manual_seed(input_seed) + + os.makedirs(config.output_dir, exist_ok=True) + output_rgb_video_path = 
os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH) + output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH) + output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH) + output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH) + + # DDIM respacing string derived from the inference-steps slider, e.g. "ddim25" + respacing = "ddim{}".format(input_num_steps) + diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion) + sample_fn = diffusion.ddim_sample_loop_progressive + fwd_fn = model.forward_with_cfg + + # text-conditioned + if input_image is None: + raise NotImplementedError + # image-conditioned (a text prompt may also be given, but image-only usually works) + else: + input_image = remove_background(input_image, rembg_session) + input_image = resize_foreground(input_image, 0.85) + raw_image = np.array(input_image) + mask = (raw_image[..., -1][..., None] > 0) * 1 + raw_image = raw_image[..., :3] * mask + input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device) + + with torch.no_grad(): + latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4) + batch = {} + inf_bs = 1 + inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device) + y = conditioner.encoder(input_cond) + model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp) + if input_cfg >= 0: + model_kwargs['cfg_scale'] = input_cfg + for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device): + final_samples = samples + recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1) + if perchannel_norm: + recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean + recon_srt_param = recon_param[:, :, 0:4] + recon_feat_param = recon_param[:, :, 4:] # [inf_bs, 2048, 64] + recon_feat_param_list = [] + # decode one sample at a time to avoid OOM + for inf_bidx in range(inf_bs): + if not perchannel_norm: + decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf) + else: + decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:])) + recon_feat_param_list.append(decoded.detach()) + recon_feat_param = torch.concat(recon_feat_param_list, dim=0) + # invert normalization + if not perchannel_norm: + recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05 + recon_feat_param[:, 0:1, ...] /= 5. + recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2. + recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1) + recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1) + visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device) + prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()} + torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir)) + + # exporting GLB mesh + denoise_param_path = os.path.join(config.output_dir, 'denoised.pt') + primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict'] + model_primx.load_state_dict(primx_ckpt_weight) + model_primx.to(device) + model_primx.eval() + with torch.no_grad(): + model_primx.srt_param[:, 1:4] *= 0.85 + extract_texmesh(config.inference, model_primx, output_glb_path, device) + + return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path + +# gradio UI +_TITLE = '''3DTopia-XL''' + +_DESCRIPTION = ''' +
+ + +
+ +* We currently offer 1) the single-image-conditioned model; 2) a multiview-image-conditioned model and 3) a pure text-conditioned model will be released in the future! +* If you find the output unsatisfying, try using different seeds! +''' + +block = gr.Blocks(title=_TITLE).queue() +with block: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + # input image + input_image = gr.Image(label="image", type='pil') + # inference steps + input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25) + # CFG scale + input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6) + # random seed + input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42) + # gen button + button_gen = gr.Button("Generate") + + with gr.Column(scale=1): + with gr.Tab("Video"): + # final video results + output_rgb_video = gr.Video(label="video") + output_prim_video = gr.Video(label="video") + output_mat_video = gr.Video(label="video") + with gr.Tab("GLB"): + # glb file + output_glb = gr.File(label="glb") + + button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb]) + + gr.Examples( + examples=[ + "assets/examples/fruit_elephant.jpg", + "assets/examples/mei_ling_panda.png", + "assets/examples/shuai_panda_notail.png", + ], + inputs=[input_image], + outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb], + fn=lambda x: process(input_image=x), + cache_examples=False, + label='Single Image to 3D PBR Asset' + ) + +block.launch(server_name="0.0.0.0", share=True) \ No newline at end of file diff --git a/assets/examples/blue_cat.png b/assets/examples/blue_cat.png new file mode 100644 index 0000000000000000000000000000000000000000..aa933e07cc2d6b058900730649637d75be19ef5f Binary files /dev/null and b/assets/examples/blue_cat.png differ diff --git a/assets/examples/bubble_mart_blue.png b/assets/examples/bubble_mart_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..af870322d4a8a2f237546fbea9560bb8e5f50364 Binary files /dev/null and b/assets/examples/bubble_mart_blue.png differ diff --git a/assets/examples/bulldog.png b/assets/examples/bulldog.png new file mode 100644 index 0000000000000000000000000000000000000000..16c598a8133643898408ea806b69d5b18c53be7d Binary files /dev/null and b/assets/examples/bulldog.png differ diff --git a/assets/examples/ceramic.png b/assets/examples/ceramic.png new file mode 100644 index 0000000000000000000000000000000000000000..46a2a336ea869397376bab68cbc45c690cc6c617 Binary files /dev/null and b/assets/examples/ceramic.png differ diff --git a/assets/examples/chair_watermelon.png b/assets/examples/chair_watermelon.png new file mode 100644 index 0000000000000000000000000000000000000000..52b39917abcbd2f1eef9b7c8cf9aa602bddde1bf Binary files /dev/null and b/assets/examples/chair_watermelon.png differ diff --git a/assets/examples/cup_rgba.png b/assets/examples/cup_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..e730cdd960b8e552dea17959bd54f52cb6f941ce Binary files /dev/null and b/assets/examples/cup_rgba.png differ diff --git a/assets/examples/cute_horse.jpg b/assets/examples/cute_horse.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ec8807d313b983e3cc34ee89bbf3f312d6ce66eb Binary files /dev/null and 
b/assets/examples/cute_horse.jpg differ diff --git a/assets/examples/earphone.jpg b/assets/examples/earphone.jpg new file mode 100644 index 0000000000000000000000000000000000000000..498e4196b0d68f8809d049e7178b80592a31a0a2 Binary files /dev/null and b/assets/examples/earphone.jpg differ diff --git a/assets/examples/firedragon.png b/assets/examples/firedragon.png new file mode 100644 index 0000000000000000000000000000000000000000..6d6c54180f5a2d362c5eb9b54a4ef3a3222516eb Binary files /dev/null and b/assets/examples/firedragon.png differ diff --git a/assets/examples/fox.jpg b/assets/examples/fox.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f2efc1c3a9c4ad8f36ad93082c124c91a6e9ef7 Binary files /dev/null and b/assets/examples/fox.jpg differ diff --git a/assets/examples/fruit_elephant.jpg b/assets/examples/fruit_elephant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ef8eaf3b88ae0a38272b34802fe40032055afa58 Binary files /dev/null and b/assets/examples/fruit_elephant.jpg differ diff --git a/assets/examples/hatsune_miku.png b/assets/examples/hatsune_miku.png new file mode 100644 index 0000000000000000000000000000000000000000..2fecf005fdd56a396c4894256fbb98fcc1c4dd8f Binary files /dev/null and b/assets/examples/hatsune_miku.png differ diff --git a/assets/examples/ikun_rgba.png b/assets/examples/ikun_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..985c37037d83a75954da4e104bbbec4f550130e2 Binary files /dev/null and b/assets/examples/ikun_rgba.png differ diff --git a/assets/examples/mailbox.png b/assets/examples/mailbox.png new file mode 100644 index 0000000000000000000000000000000000000000..cce19d56d5a68b7eeec6848dd97a1ca7be30b520 Binary files /dev/null and b/assets/examples/mailbox.png differ diff --git a/assets/examples/mario.png b/assets/examples/mario.png new file mode 100644 index 0000000000000000000000000000000000000000..d9805fdcb31e2f7f036e830e03a045e169319ddf Binary files /dev/null and b/assets/examples/mario.png differ diff --git a/assets/examples/mei_ling_panda.png b/assets/examples/mei_ling_panda.png new file mode 100644 index 0000000000000000000000000000000000000000..e7d5392b77385ff9670579ef4a10aaa5211e8067 Binary files /dev/null and b/assets/examples/mei_ling_panda.png differ diff --git a/assets/examples/mushroom_teapot.jpg b/assets/examples/mushroom_teapot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6c767354305f5467a4c0d5f199eee2a120f4501 Binary files /dev/null and b/assets/examples/mushroom_teapot.jpg differ diff --git a/assets/examples/pikachu.png b/assets/examples/pikachu.png new file mode 100644 index 0000000000000000000000000000000000000000..e7579c16957a3e13b80d53cf0a41ddfdfd47b92d Binary files /dev/null and b/assets/examples/pikachu.png differ diff --git a/assets/examples/potplant_rgba.png b/assets/examples/potplant_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..4aee4cdf0be69650d465faaa8585e01092e41ccb Binary files /dev/null and b/assets/examples/potplant_rgba.png differ diff --git a/assets/examples/seed_frog.png b/assets/examples/seed_frog.png new file mode 100644 index 0000000000000000000000000000000000000000..35524f5a10fd205370ccd5881aa6deab5431e771 Binary files /dev/null and b/assets/examples/seed_frog.png differ diff --git a/assets/examples/shuai_panda_notail.png b/assets/examples/shuai_panda_notail.png new file mode 100644 index 0000000000000000000000000000000000000000..7a4fb91fc6fc2c3432ab42a2b83c1cc80760a3b9 Binary files 
/dev/null and b/assets/examples/shuai_panda_notail.png differ diff --git a/assets/examples/yellow_duck.png b/assets/examples/yellow_duck.png new file mode 100644 index 0000000000000000000000000000000000000000..02246dd65c52e0b2cf5a6daba96deeab926928ff Binary files /dev/null and b/assets/examples/yellow_duck.png differ diff --git a/configs/inference_dit.yml b/configs/inference_dit.yml new file mode 100644 index 0000000000000000000000000000000000000000..0820949ffd5217cb3d43b9632fc6e4198b15d277 --- /dev/null +++ b/configs/inference_dit.yml @@ -0,0 +1,97 @@ +debug: False +root_data_dir: ./runs +checkpoint_path: +global_seed: 42 + +inference: + input_dir: + ddim: 25 + cfg: 6 + seed: ${global_seed} + precision: fp16 + export_glb: True + decimate: 100000 + mc_resolution: 256 + batch_size: 4096 + remesh: False + +image_height: 518 +image_width: 518 + +model: + class_name: models.primsdf.PrimSDF + num_prims: 2048 + dim_feat: 6 + prim_shape: 8 + init_scale: 0.05 # useless if auto_scale_init == True + sdf2alpha_var: 0.005 + auto_scale_init: True + init_sampling: uniform + vae: + class_name: models.vae3d_dib.VAE + in_channels: ${model.dim_feat} + latent_channels: 1 + out_channels: ${model.vae.in_channels} + down_channels: [32, 256] + mid_attention: True + up_channels: [256, 32] + layers_per_block: 2 + gradient_checkpointing: False + vae_checkpoint_path: + conditioner: + class_name: models.conditioner.image.ImageConditioner + num_prims: ${model.num_prims} + dim_feat: ${model.dim_feat} + prim_shape: ${model.prim_shape} + sample_view: False + encoder_config: + class_name: models.conditioner.image_dinov2.Dinov2Wrapper + model_name: dinov2_vitb14_reg + freeze: True + generator: + class_name: models.dit_crossattn.DiT + seq_length: ${model.num_prims} + in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3 + condition_channels: 768 + hidden_size: 1152 + depth: 28 + num_heads: 16 + attn_proj_bias: True + cond_drop_prob: 0.1 + gradient_checkpointing: False + latent_nf: 1.0 + latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166] + latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916] + +diffusion: + timestep_respacing: + noise_schedule: squaredcos_cap_v2 + diffusion_steps: 1000 + parameterization: v + +rm: + volradius: 10000.0 + dt: 1.0 + +optimizer: + class_name: torch.optim.AdamW + lr: 0.0001 + weight_decay: 0 + +scheduler: + class_name: dva.scheduler.CosineWarmupScheduler + warmup_iters: 3000 + max_iters: 200000 + +train: + batch_size: 8 + n_workers: 4 + n_epochs: 1000 + 
log_every_n_steps: 50 + summary_every_n_steps: 10000 + ckpt_every_n_steps: 10000 + amp: False + precision: tf32 + +tag: 3dtopia-xl-sview +output_dir: ${root_data_dir}/inference/${tag} diff --git a/dva/__init__.py b/dva/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d09a529b266d66bd8daa5fc6c6c1c655eb10b83 --- /dev/null +++ b/dva/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/dva/attr_dict.py b/dva/attr_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..84b85c9d3c5f6065a919fffa80ea36ed99c7e848 --- /dev/null +++ b/dva/attr_dict.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import json + + +class AttrDict: + def __init__(self, entries): + self.add_entries_(entries) + + def keys(self): + return self.__dict__.keys() + + def values(self): + return self.__dict__.values() + + def __getitem__(self, key): + return self.__dict__[key] + + def __setitem__(self, key, value): + self.__dict__[key] = value + + def __delitem__(self, key): + return self.__dict__.__delitem__(key) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + def __getattr__(self, attr): + if attr.startswith("__"): + return self.__getattribute__(attr) + return self.__dict__[attr] + + def items(self): + return self.__dict__.items() + + def __iter__(self): + return iter(self.items()) + + def add_entries_(self, entries, overwrite=True): + for key, value in entries.items(): + if key not in self.__dict__: + if isinstance(value, dict): + self.__dict__[key] = AttrDict(value) + else: + self.__dict__[key] = value + else: + if isinstance(value, dict): + self.__dict__[key].add_entries_(entries=value, overwrite=overwrite) + elif overwrite or self.__dict__[key] is None: + self.__dict__[key] = value + + def serialize(self): + return json.dumps(self, default=self.obj_to_dict, indent=4) + + def obj_to_dict(self, obj): + return obj.__dict__ + + def get(self, key, default=None): + return self.__dict__.get(key, default) diff --git a/dva/geom.py b/dva/geom.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc9aa6bed0e7e440c41384af2a867ebe4f8d131 --- /dev/null +++ b/dva/geom.py @@ -0,0 +1,653 @@ +from typing import Optional +import numpy as np +import torch as th +import torch.nn.functional as F +import torch.nn as nn + +from sklearn.neighbors import KDTree + +import logging + +logger = logging.getLogger(__name__) + +# NOTE: we need pytorch3d primarily for UV rasterization things +from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes +from pytorch3d.structures import Meshes +from typing import Union, Optional, Tuple +import trimesh +from trimesh import Trimesh +from trimesh.triangles import points_to_barycentric + +try: + # pyre-fixme[21]: Could not find module `igl`. + from igl import point_mesh_squared_distance # @manual + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
+ def closest_point(mesh, points): + """Helper function that mimics trimesh.proximity.closest_point but uses + IGL for faster queries.""" + v = mesh.vertices + vi = mesh.faces + dist, face_idxs, p = point_mesh_squared_distance(points, v, vi) + return p, dist, face_idxs + +except ImportError: + from trimesh.proximity import closest_point + + +def closest_point_barycentrics(v, vi, points): + """Given a 3D mesh and a set of query points, return closest point barycentrics + Args: + v: np.array (float) + [N, 3] mesh vertices + + vi: np.array (int) + [N, 3] mesh triangle indices + + points: np.array (float) + [M, 3] query points + + Returns: + Tuple[approx, barys, interp_idxs, face_idxs] + approx: [M, 3] approximated (closest) points on the mesh + barys: [M, 3] barycentric weights that produce "approx" + interp_idxs: [M, 3] vertex indices for barycentric interpolation + face_idxs: [M] face indices for barycentric interpolation. interp_idxs = vi[face_idxs] + """ + mesh = Trimesh(vertices=v, faces=vi, process=False) + p, _, face_idxs = closest_point(mesh, points) + p = p.reshape((points.shape[0], 3)) + face_idxs = face_idxs.reshape((points.shape[0],)) + barys = points_to_barycentric(mesh.triangles[face_idxs], p) + b0, b1, b2 = np.split(barys, 3, axis=1) + + interp_idxs = vi[face_idxs] + v0 = v[interp_idxs[:, 0]] + v1 = v[interp_idxs[:, 1]] + v2 = v[interp_idxs[:, 2]] + approx = b0 * v0 + b1 * v1 + b2 * v2 + return approx, barys, interp_idxs, face_idxs + +def make_uv_face_index( + vt: th.Tensor, + vti: th.Tensor, + uv_shape: Union[Tuple[int, int], int], + flip_uv: bool = True, + device: Optional[Union[str, th.device]] = None, +): + """Compute a UV-space face index map identifying which mesh face contains each + texel. For texels with no assigned triangle, the index will be -1.""" + + if isinstance(uv_shape, int): + uv_shape = (uv_shape, uv_shape) + + uv_max_shape_ind = uv_shape.index(max(uv_shape)) + uv_min_shape_ind = uv_shape.index(min(uv_shape)) + uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind] + + if device is not None: + if isinstance(device, str): + dev = th.device(device) + else: + dev = device + assert dev.type == "cuda" + else: + dev = th.device("cuda") + + vt = 1.0 - vt.clone() + + if flip_uv: + vt = vt.clone() + vt[:, 1] = 1 - vt[:, 1] + vt_pix = 2.0 * vt.to(dev) - 1.0 + vt_pix = th.cat([vt_pix, th.ones_like(vt_pix[:, 0:1])], dim=1) + + vt_pix[:, uv_min_shape_ind] *= uv_ratio + meshes = Meshes(vt_pix[np.newaxis], vti[np.newaxis].to(dev)) + with th.no_grad(): + face_index, _, _, _ = rasterize_meshes( + meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0 + ) + face_index = face_index[0, ..., 0] + return face_index + + +def make_uv_vert_index( + vt: th.Tensor, + vi: th.Tensor, + vti: th.Tensor, + uv_shape: Union[Tuple[int, int], int], + flip_uv: bool = True, +): + """Compute a UV-space vertex index map identifying which mesh vertices + comprise the triangle containing each texel. For texels with no assigned + triangle, all indices will be -1. 
+ """ + face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv) + vert_index_map = vi[face_index_map.clamp(min=0)] + vert_index_map[face_index_map < 0] = -1 + return vert_index_map.long() + + +def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6): + """Computes barycentric coordinates for a set of 2D query points given + coordintes for the 3 vertices of the enclosing triangle for each point.""" + x = points[:, 0] - triangles[2, :, 0] + x1 = triangles[0, :, 0] - triangles[2, :, 0] + x2 = triangles[1, :, 0] - triangles[2, :, 0] + y = points[:, 1] - triangles[2, :, 1] + y1 = triangles[0, :, 1] - triangles[2, :, 1] + y2 = triangles[1, :, 1] - triangles[2, :, 1] + denom = y2 * x1 - y1 * x2 + n0 = y2 * x - x2 * y + n1 = x1 * y - y1 * x + + # Small epsilon to prevent divide-by-zero error. + denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps)) + + bary_0 = n0 / denom + bary_1 = n1 / denom + bary_2 = 1.0 - bary_0 - bary_1 + + return th.stack((bary_0, bary_1, bary_2)) + + +def make_uv_barys( + vt: th.Tensor, + vti: th.Tensor, + uv_shape: Union[Tuple[int, int], int], + flip_uv: bool = True, +): + """Compute a UV-space barycentric map where each texel contains barycentric + coordinates for that texel within its enclosing UV triangle. For texels + with no assigned triangle, all 3 barycentric coordinates will be 0. + """ + if isinstance(uv_shape, int): + uv_shape = (uv_shape, uv_shape) + + if flip_uv: + # Flip here because texture coordinates in some of our topo files are + # stored in OpenGL convention with Y=0 on the bottom of the texture + # unlike numpy/torch arrays/tensors. + vt = vt.clone() + vt[:, 1] = 1 - vt[:, 1] + + face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False) + vti_map = vti.long()[face_index_map.clamp(min=0)] + + uv_max_shape_ind = uv_shape.index(max(uv_shape)) + uv_min_shape_ind = uv_shape.index(min(uv_shape)) + uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind] + vt = vt.clone() + vt = vt * 2 - 1 + vt[:, uv_min_shape_ind] *= uv_ratio + uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3) + + uv_grid = th.meshgrid( + th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0], + th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1], + ) + uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs) + uv_grid = uv_grid * 2 - 1 + uv_grid[..., uv_min_shape_ind] *= uv_ratio + + bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2)) + bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3) + bary_map[face_index_map < 0] = 0 + return face_index_map, bary_map + + +def index_image_impaint( + index_image: th.Tensor, + bary_image: Optional[th.Tensor] = None, + distance_threshold=100.0, +): + # getting the mask around the indexes? + if len(index_image.shape) == 3: + valid_index = (index_image != -1).any(dim=-1) + elif len(index_image.shape) == 2: + valid_index = index_image != -1 + else: + raise ValueError("`index_image` should be a [H,W] or [H,W,C] image") + + invalid_index = ~valid_index + + device = index_image.device + + valid_ij = th.stack(th.where(valid_index), dim=-1) + invalid_ij = th.stack(th.where(invalid_index), dim=-1) + lookup_valid = KDTree(valid_ij.cpu().numpy()) + + dists, idxs = lookup_valid.query(invalid_ij.cpu()) + + # TODO: try average? 
+ idxs = th.as_tensor(idxs, device=device)[..., 0] + dists = th.as_tensor(dists, device=device)[..., 0] + + dist_mask = dists < distance_threshold + + invalid_border = th.zeros_like(invalid_index) + invalid_border[invalid_index] = dist_mask + + invalid_src_ij = valid_ij[idxs][dist_mask] + invalid_dst_ij = invalid_ij[dist_mask] + + index_image_imp = index_image.clone() + + index_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = index_image[ + invalid_src_ij[:, 0], invalid_src_ij[:, 1] + ] + + if bary_image is not None: + bary_image_imp = bary_image.clone() + + bary_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = bary_image[ + invalid_src_ij[:, 0], invalid_src_ij[:, 1] + ] + + return index_image_imp, bary_image_imp + return index_image_imp + + +class GeometryModule(nn.Module): + def __init__( + self, + v, + vi, + vt, + vti, + uv_size, + v2uv: Optional[th.Tensor] = None, + flip_uv=False, + impaint=False, + impaint_threshold=100.0, + ): + super().__init__() + + self.register_buffer("v", th.as_tensor(v)) + self.register_buffer("vi", th.as_tensor(vi)) + self.register_buffer("vt", th.as_tensor(vt)) + self.register_buffer("vti", th.as_tensor(vti)) + if v2uv is not None: + self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64)) + + # TODO: should we just pass topology here? + # self.n_verts = v2uv.shape[0] + self.n_verts = vi.max() + 1 + + self.uv_size = uv_size + + # TODO: can't we just index face_index? + index_image = make_uv_vert_index( + self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv + ).cpu() + face_index, bary_image = make_uv_barys( + self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv + ) + if impaint: + if min(uv_size) >= 1024: + logger.info( + "impainting index image might take a while for sizes >= 1024" + ) + + index_image, bary_image = index_image_impaint( + index_image, bary_image, impaint_threshold + ) + # TODO: we can avoid doing this 2x + face_index = index_image_impaint( + face_index, distance_threshold=impaint_threshold + ) + + self.register_buffer("index_image", index_image.cpu()) + self.register_buffer("bary_image", bary_image.cpu()) + self.register_buffer("face_index_image", face_index.cpu()) + + def render_index_images(self, uv_size, flip_uv=False, impaint=False): + index_image = make_uv_vert_index( + self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv + ) + face_image, bary_image = make_uv_barys( + self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv + ) + + if impaint: + index_image, bary_image = index_image_impaint( + index_image, + bary_image, + ) + + return index_image, face_image, bary_image + + def vn(self, verts): + return vert_normals(verts, self.vi[np.newaxis].to(th.long)) + + def to_uv(self, values): + return values_to_uv(values, self.index_image, self.bary_image) + + def from_uv(self, values_uv): + # TODO: we need to sample this + return sample_uv(values_uv, self.vt, self.v2uv.to(th.long)) + + def rand_sample_3d_uv(self, count, uv_img): + """ + Sample a set of 3D points on the surface of mesh, return corresponding interpolated values in UV space. 
+ + Args: + count - num of 3D points to be sampled + + uv_img - the image in uv space to be sampled, e.g., texture + """ + _mesh = Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.vi.detach().cpu().numpy(), process=False) + points, _ = trimesh.sample.sample_surface(_mesh, count) + return self.sample_uv_from_3dpts(points, uv_img) + + def sample_uv_from_3dpts(self, points, uv_img): + num_pts = points.shape[0] + approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points) + interp_uv_coords = self.vt[interp_idxs, :] # [N, 3, 2] + # do bary interp first to get interp_uv_coord in high-reso uv space + target_uv_coords = th.sum(interp_uv_coords * th.from_numpy(barys)[..., None], dim=1).float() + # then directly sample from uv space + sampled_values = sample_uv(values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords) # [1, count, c] + approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2]) + return approx_values.numpy(), points + + def vert_sample_uv(self, uv_img): + count = self.v.shape[0] + points = self.v.detach().cpu().numpy() + approx_values, _ = self.sample_uv_from_3dpts(points, uv_img) + return approx_values + + +def sample_uv( + values_uv, + uv_coords, + v2uv: Optional[th.Tensor] = None, + mode: str = "bilinear", + align_corners: bool = True, + flip_uvs: bool = False, +): + batch_size = values_uv.shape[0] + + if flip_uvs: + uv_coords = uv_coords.clone() + uv_coords[:, 1] = 1.0 - uv_coords[:, 1] + + # uv_coords_norm is [1, N, 1, 2] afterwards + uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand( + batch_size, -1, -1, -1 + ) + # uv_shape = values_uv.shape[-2:] + # uv_max_shape_ind = uv_shape.index(max(uv_shape)) + # uv_min_shape_ind = uv_shape.index(min(uv_shape)) + # uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind] + # uv_coords_norm[..., uv_min_shape_ind] *= uv_ratio + + values = ( + F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode) + .squeeze(-1) + .permute((0, 2, 1)) + ) + + if v2uv is not None: + values_duplicate = values[:, v2uv] + values = values_duplicate.mean(2) + + return values + + +def values_to_uv(values, index_img, bary_img): + uv_size = index_img.shape + index_mask = th.all(index_img != -1, dim=-1) + idxs_flat = index_img[index_mask].to(th.int64) + bary_flat = bary_img[index_mask].to(th.float32) + # NOTE: here we assume + values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1) + values_uv = th.zeros( + values.shape[0], + values.shape[-1], + uv_size[0], + uv_size[1], + dtype=values.dtype, + device=values.device, + ) + values_uv[:, :, index_mask] = values_flat + return values_uv + + +def face_normals(v, vi, eps: float = 1e-5): + pts = v[:, vi] + v0 = pts[:, :, 1] - pts[:, :, 0] + v1 = pts[:, :, 2] - pts[:, :, 0] + n = th.cross(v0, v1, dim=-1) + norm = th.norm(n, dim=-1, keepdim=True) + norm[norm < eps] = 1 + n /= norm + return n + + +def vert_normals(v, vi, eps: float = 1.0e-5): + fnorms = face_normals(v, vi) + fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3) + vi_flat = vi.view(1, -1).expand(v.shape[0], -1) + vnorms = th.zeros_like(v) + for j in range(3): + vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j]) + norm = th.norm(vnorms, dim=-1, keepdim=True) + norm[norm < eps] = 1 + vnorms /= norm + return vnorms + + +def compute_view_cos(verts, faces, camera_pos): + vn = F.normalize(vert_normals(verts, faces), dim=-1) + v2c = 
F.normalize(verts - camera_pos[:, np.newaxis], dim=-1) + return th.einsum("bnd,bnd->bn", vn, v2c) + + +def compute_tbn(geom, vt, vi, vti): + """Computes tangent, bitangent, and normal vectors given a mesh. + Args: + geom: [N, n_verts, 3] th.Tensor + Vertex positions. + vt: [n_uv_coords, 2] th.Tensor + UV coordinates. + vi: [..., 3] th.Tensor + Face vertex indices. + vti: [..., 3] th.Tensor + Face UV indices. + Returns: + [..., 3] th.Tensors for T, B, N. + """ + + v0 = geom[:, vi[..., 0]] + v1 = geom[:, vi[..., 1]] + v2 = geom[:, vi[..., 2]] + vt0 = vt[vti[..., 0]] + vt1 = vt[vti[..., 1]] + vt2 = vt[vti[..., 2]] + + v01 = v1 - v0 + v02 = v2 - v0 + vt01 = vt1 - vt0 + vt02 = vt2 - vt0 + f = 1.0 / ( + vt01[None, ..., 0] * vt02[None, ..., 1] + - vt01[None, ..., 1] * vt02[None, ..., 0] + ) + tangent = f[..., None] * th.stack( + [ + v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1], + v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1], + v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1], + ], + dim=-1, + ) + tangent = F.normalize(tangent, dim=-1) + normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1) + bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1) + + return tangent, bitangent, normal + + +def compute_v2uv(n_verts, vi, vti, n_max=4): + """Computes mapping from vertex indices to texture indices. + + Args: + vi: [F, 3], triangles + vti: [F, 3], texture triangles + n_max: int, max number of texture locations + + Returns: + [n_verts, n_max], texture indices + """ + v2uv_dict = {} + for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)): + v2uv_dict.setdefault(i_v, set()).add(i_uv) + assert len(v2uv_dict) == n_verts + v2uv = np.zeros((n_verts, n_max), dtype=np.int32) + for i in range(n_verts): + vals = sorted(list(v2uv_dict[i])) + v2uv[i, :] = vals[0] + v2uv[i, : len(vals)] = np.array(vals) + return v2uv + + +def compute_neighbours(n_verts, vi, n_max_values=10): + """Computes first-ring neighbours given vertices and faces.""" + n_vi = vi.shape[0] + + adj = {i: set() for i in range(n_verts)} + for i in range(n_vi): + for idx in vi[i]: + adj[idx] |= set(vi[i]) - set([idx]) + + nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values)) + nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32) + + for idx in range(n_verts): + n_values = min(len(adj[idx]), n_max_values) + nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values] + nbs_weights[idx, :n_values] = -1.0 / n_values + + return nbs_idxs, nbs_weights + + +def make_postex(v, idxim, barim): + return ( + barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]] + + barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]] + + barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]] + ).permute(0, 3, 1, 2) + + +def matrix_to_axisangle(r): + th = th.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0))[..., None] + vec = ( + 0.5 + * th.stack( + [ + r[..., 2, 1] - r[..., 1, 2], + r[..., 0, 2] - r[..., 2, 0], + r[..., 1, 0] - r[..., 0, 1], + ], + dim=-1, + ) + / th.sin(th) + ) + return th, vec + + +def axisangle_to_matrix(rvec): + theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1)) + rvec = rvec / theta[..., None] + costh = th.cos(theta) + sinth = th.sin(theta) + return th.stack( + ( + th.stack( + ( + rvec[..., 0] ** 2 + (1.0 - rvec[..., 0] ** 2) * costh, + rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth, + rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth, + ), + dim=-1, + ), + th.stack( + ( + rvec[..., 0] * rvec[..., 1] * (1.0 - 
costh) + rvec[..., 2] * sinth, + rvec[..., 1] ** 2 + (1.0 - rvec[..., 1] ** 2) * costh, + rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth, + ), + dim=-1, + ), + th.stack( + ( + rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth, + rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth, + rvec[..., 2] ** 2 + (1.0 - rvec[..., 2] ** 2) * costh, + ), + dim=-1, + ), + ), + dim=-2, + ) + + +def rotation_interp(r0, r1, alpha): + r0a = r0.view(-1, 3, 3) + r1a = r1.view(-1, 3, 3) + r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0) + + th, rvec = matrix_to_axisangle(r) + rvec = rvec * (alpha * th) + + r = axisangle_to_matrix(rvec) + return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0) + + +def convert_camera_parameters(Rt, K): + R = Rt[:, :3, :3] + t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2) + return dict( + campos=t, + camrot=R, + focal=K[:, :2, :2], + princpt=K[:, :2, 2], + ) + + +def project_points_multi(p, Rt, K, normalize=False, size=None): + """Project a set of 3D points into multiple cameras with a pinhole model. + Args: + p: [B, N, 3], input 3D points in world coordinates + Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to) + K: [B, NC, 3, 3], intrinsics + normalize: bool, whether to normalize coordinates to [-1.0, 1.0] + Returns: + tuple: + - [B, NC, N, 2] - projected points + - [B, NC, N] - their + """ + B, N = p.shape[:2] + NC = Rt.shape[1] + + Rt = Rt.reshape(B * NC, 3, 4) + K = K.reshape(B * NC, 3, 3) + + # [B, N, 3] -> [B * NC, N, 3] + p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3) + p_cam = p @ Rt[:, :3, :3].transpose(-2, -1) + Rt[:, :3, 3][:, np.newaxis] + p_pix = p_cam @ K.transpose(-2, -1) + p_depth = p_pix[:, :, 2:] + p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2) + p_depth = p_depth.reshape(B, NC, N) + + if normalize: + assert size is not None + h, w = size + p_pix = ( + 2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0 + ) + return p_pix, p_depth diff --git a/dva/io.py b/dva/io.py new file mode 100644 index 0000000000000000000000000000000000000000..d2875b51a2bfdbd2c00f904b6b732583fcd0b182 --- /dev/null +++ b/dva/io.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
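For reference, `dva/io.py` (next hunk) provides the small `class_name`-based factory used throughout `app.py` and `configs/inference_dit.yml`: `load_class` resolves a dotted path and `load_from_config` instantiates it with the remaining config keys. A minimal usage sketch, assuming the repository root is importable and its dependencies (e.g. opencv for `dva.io`) are installed; the AdamW entry simply mirrors the `optimizer` block of the config above and is illustrative only:

```python
# Sketch (not part of the diff): how dva.io.load_from_config turns a config
# entry into an object. "class_name" is popped, the remaining keys plus any
# extra kwargs become constructor arguments.
import torch.nn as nn
from omegaconf import OmegaConf
from dva.io import load_from_config

cfg = OmegaConf.create({
    "class_name": "torch.optim.AdamW",  # dotted path resolved by load_class()
    "lr": 0.0001,
    "weight_decay": 0,
})
params = nn.Linear(4, 4).parameters()   # toy parameters, illustrative only
opt = load_from_config(cfg, params=params)
print(type(opt))  # <class 'torch.optim.adamw.AdamW'>
```

`app.py` uses the same mechanism for the `generator`, `vae`, `conditioner`, and bare `model` entries of the config.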
+ +import json +import cv2 +import numpy as np +import copy +import importlib +from typing import Any, Dict + +def load_module(module_name, class_name=None, silent: bool = False): + module = importlib.import_module(module_name) + return getattr(module, class_name) if class_name else module + + +def load_class(class_name): + return load_module(*class_name.rsplit(".", 1)) + + +def load_from_config(config, **kwargs): + """Instantiate an object given a config and arguments.""" + assert "class_name" in config and "module_name" not in config + config = copy.deepcopy(config) + class_name = config.pop("class_name") + object_class = load_class(class_name) + return object_class(**config, **kwargs) + + +def load_opencv_calib(extrin_path, intrin_path): + cameras = {} + + fse = cv2.FileStorage() + fse.open(extrin_path, cv2.FileStorage_READ) + + fsi = cv2.FileStorage() + fsi.open(intrin_path, cv2.FileStorage_READ) + + names = [ + fse.getNode("names").at(c).string() for c in range(fse.getNode("names").size()) + ] + + for camera in names: + rot = fse.getNode(f"R_{camera}").mat() + R = fse.getNode(f"Rot_{camera}").mat() + T = fse.getNode(f"T_{camera}").mat() + R_pred = cv2.Rodrigues(rot)[0] + assert np.all(np.isclose(R_pred, R)) + K = fsi.getNode(f"K_{camera}").mat() + cameras[camera] = { + "Rt": np.concatenate([R, T], axis=1).astype(np.float32), + "K": K.astype(np.float32), + } + return cameras \ No newline at end of file diff --git a/dva/layers.py b/dva/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8d219ecfed8c5e8ad2dfb1f7a1204bffa2ea2344 --- /dev/null +++ b/dva/layers.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch as th +import torch.nn as nn + +import numpy as np + +from dva.mvp.models.utils import Conv2dWN, Conv2dWNUB, ConvTranspose2dWNUB, initmod + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + size, + lrelu_slope=0.2, + kernel_size=3, + padding=1, + wnorm_dim=0, + ): + super().__init__() + + self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1) + self.conv1 = Conv2dWNUB( + in_channels, + in_channels, + kernel_size=kernel_size, + padding=padding, + height=size, + width=size, + ) + + self.lrelu1 = nn.LeakyReLU(lrelu_slope) + self.conv2 = Conv2dWNUB( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + height=size, + width=size, + ) + self.lrelu2 = nn.LeakyReLU(lrelu_slope) + + def forward(self, x): + x_skip = self.conv_resize(x) + x = self.conv1(x) + x = self.lrelu1(x) + x = self.conv2(x) + x = self.lrelu2(x) + return x + x_skip + + +def tile2d(x, size: int): + """Tile a given set of features into a convolutional map. + + Args: + x: float tensor of shape [N, F] + size: int or a tuple + + Returns: + a feature map [N, F, size[0], size[1]] + """ + # size = size if isinstance(size, tuple) else (size, size) + # NOTE: expecting only int here (!!!) + return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size) + + +def weights_initializer(m, alpha: float = 1.0): + return initmod(m, nn.init.calculate_gain("leaky_relu", alpha)) + + +class UNetWB(nn.Module): + def __init__( + self, + in_channels, + out_channels, + size, + n_init_ftrs=8, + out_scale=0.1, + ): + # super().__init__(*args, **kwargs) + super().__init__() + + self.out_scale = 0.1 + + F = n_init_ftrs + + # TODO: allow changing the size? 
+ self.size = size + + self.down1 = nn.Sequential( + Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.down2 = nn.Sequential( + Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.down3 = nn.Sequential( + Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.down4 = nn.Sequential( + Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.down5 = nn.Sequential( + Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.up1 = nn.Sequential( + ConvTranspose2dWNUB( + 16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1 + ), + nn.LeakyReLU(0.2), + ) + self.up2 = nn.Sequential( + ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.up3 = nn.Sequential( + ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.up4 = nn.Sequential( + ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1), + nn.LeakyReLU(0.2), + ) + self.up5 = nn.Sequential( + ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2) + ) + self.out = Conv2dWNUB( + F + in_channels, out_channels, self.size, self.size, kernel_size=1 + ) + self.apply(lambda x: initmod(x, 0.2)) + initmod(self.out, 1.0) + + def forward(self, x): + x1 = x + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x6 = self.down5(x5) + # TODO: switch to concat? + x = self.up1(x6) + x5 + x = self.up2(x) + x4 + x = self.up3(x) + x3 + x = self.up4(x) + x2 + x = self.up5(x) + x = th.cat([x, x1], dim=1) + return self.out(x) * self.out_scale diff --git a/dva/losses.py b/dva/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..850475114bbbf06f552646b548d6f926813c0098 --- /dev/null +++ b/dva/losses.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
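The loss modules added below share one calling convention: `forward(inputs, preds, iteration)` reads the target from `inputs['gt']`, the reconstruction and posterior from `preds`, and returns `(loss_total, loss_dict)`, with weights looked up by attribute. A small sketch of that contract for `VAELoss`, assuming the full `dva` package (including `dva.vgg`) is importable; the posterior stub and tensor shapes are illustrative, not taken from the repository:

```python
# Sketch (not part of the diff) of the VAELoss calling convention.
import torch as th
from omegaconf import OmegaConf
from dva.losses import VAELoss

class _FakePosterior:
    def kl(self):
        return th.zeros(2)  # per-sample KL; .mean() is taken inside the loss

weights = OmegaConf.create({"recon": 1.0, "kl": 1e-4})
criterion = VAELoss(weights)

gt = th.rand(2, 6, 8, 8, 8)  # e.g. [B, dim_feat, prim_shape^3] volumes, toy values
preds = {"recon": th.rand_like(gt), "posterior": _FakePosterior()}
loss_total, loss_dict = criterion({"gt": gt}, preds, iteration=0)
print(sorted(loss_dict))  # ['loss_kl', 'loss_recon_l1', 'loss_total']
```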
+ +import torch.nn as nn +import torch as th +import numpy as np + +import logging + +from .vgg import VGGLossMasked + +logger = logging.getLogger("dva.{__name__}") + +class DCTLoss(nn.Module): + def __init__(self, weights): + super().__init__() + self.weights = weights + + def forward(self, inputs, preds, iteration=None): + loss_dict = {"loss_total": 0.0} + target = inputs['gt'] + recon = preds['recon'] + posterior = preds['posterior'] + fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1))) + fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1))) + loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon)) + loss_recon_l1 = th.mean(th.abs(target - recon)) + loss_kl = posterior.kl().mean() + loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl) + loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl + + loss_dict["loss_total"] = loss_total + return loss_total, loss_dict + +class VAESepL2Loss(nn.Module): + def __init__(self, weights): + super().__init__() + self.weights = weights + + def forward(self, inputs, preds, iteration=None): + loss_dict = {"loss_total": 0.0} + target = inputs['gt'] + recon = preds['recon'] + posterior = preds['posterior'] + recon_diff = (target - recon) ** 2 + loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...]) + loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...]) + loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...]) + loss_kl = posterior.kl().mean() + loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl) + loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1 + if "kl" in self.weights: + loss_total += self.weights.kl * loss_kl + + loss_dict["loss_total"] = loss_total + return loss_total, loss_dict + +class VAESepLoss(nn.Module): + def __init__(self, weights): + super().__init__() + self.weights = weights + + def forward(self, inputs, preds, iteration=None): + loss_dict = {"loss_total": 0.0} + target = inputs['gt'] + recon = preds['recon'] + posterior = preds['posterior'] + recon_diff = th.abs(target - recon) + loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...]) + loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...]) + loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...]) + loss_kl = posterior.kl().mean() + loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl) + loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1 + if "kl" in self.weights: + loss_total += self.weights.kl * loss_kl + + loss_dict["loss_total"] = loss_total + return loss_total, loss_dict + +class VAELoss(nn.Module): + def __init__(self, weights): + super().__init__() + self.weights = weights + + def forward(self, inputs, preds, iteration=None): + loss_dict = {"loss_total": 0.0} + target = inputs['gt'] + recon = preds['recon'] + posterior = preds['posterior'] + loss_recon_l1 = th.mean(th.abs(target - recon)) + loss_kl = posterior.kl().mean() + loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl) + loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl + + loss_dict["loss_total"] = loss_total + return loss_total, loss_dict + +class PrimSDFLoss(nn.Module): + def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000): + super().__init__() + self.weights = weights + 
self.shape_opt_steps = shape_opt_steps + self.tex_opt_steps = tex_opt_steps + + def forward(self, inputs, preds, iteration=None): + loss_dict = {"loss_total": 0.0} + + if iteration < self.shape_opt_steps: + target_sdf = inputs['sdf'] + sdf = preds['sdf'] + loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf)) + loss_dict.update(loss_sdf_l1=loss_sdf_l1) + loss_total = self.weights.sdf_l1 * loss_sdf_l1 + + prim_scale = preds["prim_scale"] + # we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube + if "vol_sum" in self.weights: + loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1)) + loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum) + loss_total += self.weights.vol_sum * loss_prim_vol_sum + + if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps: + target_tex = inputs['tex'] + tex = preds['tex'] + loss_tex_l1 = th.mean(th.abs(tex - target_tex)) + loss_dict.update(loss_tex_l1=loss_tex_l1) + + loss_total = ( + self.weights.rgb_l1 * loss_tex_l1 + ) + if "mat_l1" in self.weights: + target_mat = inputs['mat'] + mat = preds['mat'] + loss_mat_l1 = th.mean(th.abs(mat - target_mat)) + loss_dict.update(loss_mat_l1=loss_mat_l1) + loss_total += self.weights.mat_l1 * loss_mat_l1 + + if "grad_l2" in self.weights: + loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2) + loss_total += self.weights.grad_l2 * loss_grad_l2 + loss_dict.update(loss_grad_l2=loss_grad_l2) + + loss_dict["loss_total"] = loss_total + return loss_total, loss_dict + + +class TotalMVPLoss(nn.Module): + def __init__(self, weights, assets=None): + super().__init__() + + self.weights = weights + + if "vgg" in self.weights: + self.vgg_loss = VGGLossMasked() + + def forward(self, inputs, preds, iteration=None): + + loss_dict = {"loss_total": 0.0} + + B = inputs["image"].shape + + # rgb + target_rgb = inputs["image"].permute(0, 2, 3, 1) + # removing the mask + target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis] + + rgb = preds["rgb"] + loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0) + loss_dict.update(loss_rgb_mse=loss_rgb_mse) + + alpha = preds["alpha"] + + # mask loss + target_mask = inputs["image_mask"][:, 0].to(th.float32) + loss_mask_mae = th.mean((target_mask - alpha).abs()) + loss_dict.update(loss_mask_mae=loss_mask_mae) + + B = alpha.shape[0] + + # beta prior on opacity + loss_alpha_prior = th.mean( + th.log(0.1 + alpha.reshape(B, -1)) + + th.log(0.1 + 1.0 - alpha.reshape(B, -1)) + - -2.20727 + ) + loss_dict.update(loss_alpha_prior=loss_alpha_prior) + + prim_scale = preds["prim_scale"] + loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1)) + loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum) + + loss_total = ( + self.weights.rgb_mse * loss_rgb_mse + + self.weights.mask_mae * loss_mask_mae + + self.weights.alpha_prior * loss_alpha_prior + + self.weights.prim_vol_sum * loss_prim_vol_sum + ) + + if "embs_l2" in self.weights: + loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1)) + loss_total += self.weights.embs_l2 * loss_embs_l2 + loss_dict.update(loss_embs_l2=loss_embs_l2) + + if "vgg" in self.weights: + loss_vgg = self.vgg_loss( + rgb.permute(0, 3, 1, 2), + target_rgb.permute(0, 3, 1, 2), + inputs["image_mask"], + ) + loss_total += self.weights.vgg * loss_vgg + loss_dict.update(loss_vgg=loss_vgg) + + if "prim_scale_var" in self.weights: + log_prim_scale = th.log(prim_scale) + # NOTE: should we detach this? 
+ log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True) + loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0) + loss_total += self.weights.prim_scale_var * loss_prim_scale_var + loss_dict.update(loss_prim_scale_var=loss_prim_scale_var) + + loss_dict["loss_total"] = loss_total + + return loss_total, loss_dict + + +def process_losses(loss_dict, reduce=True, detach=True): + """Preprocess the dict of losses outputs.""" + result = { + k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_") + } + if detach: + result = {k: v.detach() for k, v in result.items()} + if reduce: + result = {k: float(v.mean().item()) for k, v in result.items()} + return result diff --git a/dva/mvp/extensions/mvpraymarch/bvh.cu b/dva/mvp/extensions/mvpraymarch/bvh.cu new file mode 100644 index 0000000000000000000000000000000000000000..d203b40eddb12631d92c41d6051824a43da11236 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/bvh.cu @@ -0,0 +1,292 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +#include "helper_math.h" + +#include "cudadispatch.h" + +#include "primtransf.h" + +// Expands a 10-bit integer into 30 bits +// by inserting 2 zeros after each bit. +__device__ unsigned int expand_bits(unsigned int v) { + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +// Calculates a 30-bit Morton code for the +// given 3D point located within the unit cube [0,1]. +__device__ unsigned int morton3D(float x, float y, float z) { + x = fminf(fmaxf(x * 1024.0f, 0.0f), 1023.0f); + y = fminf(fmaxf(y * 1024.0f, 0.0f), 1023.0f); + z = fminf(fmaxf(z * 1024.0f, 0.0f), 1023.0f); + unsigned int xx = expand_bits((unsigned int)x); + unsigned int yy = expand_bits((unsigned int)y); + unsigned int zz = expand_bits((unsigned int)z); + return xx * 4 + yy * 2 + zz; +} + +template +__global__ void compute_morton_kernel( + int N, int K, + typename PrimTransfT::Data data, + int * code + ) { + const int count = N * K; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) { + const int k = index % K; + const int n = index / K; + + //float4 c = center[n * K + k]; + float3 c = data.get_center(n, k); + code[n * K + k] = morton3D(c.x, c.y, c.z); + } +} + +__forceinline__ __device__ int delta(int* sortedcodes, int x, int y, int K) { + if (x >= 0 && x <= K - 1 && y >= 0 && y <= K - 1) { + return sortedcodes[x] == sortedcodes[y] ? 
+ 32 + __clz(x ^ y) : + __clz(sortedcodes[x] ^ sortedcodes[y]); + } + return -1; +} + +__forceinline__ __device__ int sign(int x) { + return (int)(x > 0) - (int)(x < 0); +} + +__device__ int find_split( + int* sortedcodes, + int first, + int last, + int K) { + float commonPrefix = delta(sortedcodes, first, last, K); + int split = first; + int step = last - first; + + do { + step = (step + 1) >> 1; // exponential decrease + int newSplit = split + step; // proposed new position + + if (newSplit < last) { + int splitPrefix = delta(sortedcodes, first, newSplit, K); + if (splitPrefix > commonPrefix) { + split = newSplit; // accept proposal + } + } + } while (step > 1); + + return split; +} + +__device__ int2 determine_range(int* sortedcodes, int K, int idx) { + int d = sign(delta(sortedcodes, idx, idx + 1, K) - delta(sortedcodes, idx, idx - 1, K)); + int dmin = delta(sortedcodes, idx, idx - d, K); + int lmax = 2; + while (delta(sortedcodes, idx, idx + lmax * d, K) > dmin) { + lmax = lmax * 2; + } + + int l = 0; + for (int t = lmax / 2; t >= 1; t /= 2) { + if (delta(sortedcodes, idx, idx + (l + t)*d, K) > dmin) { + l += t; + } + } + + int j = idx + l*d; + int2 range; + range.x = min(idx, j); + range.y = max(idx, j); + + return range; +} + +__global__ void build_tree_kernel( + int N, int K, + int * sortedcodes, + int2 * nodechildren, + int * nodeparent) { + const int count = N * (K + K - 1); + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) { + const int k = index % (K + K - 1); + const int n = index / (K + K - 1); + + if (k >= K - 1) { + // leaf + nodechildren[n * (K + K - 1) + k] = make_int2(-(k - (K - 1)) - 1, -(k - (K - 1)) - 2); + } else { + // internal node + + // find out which range of objects the node corresponds to + int2 range = determine_range(sortedcodes + n * K, K, k); + int first = range.x; + int last = range.y; + + // determine where to split the range + int split = find_split(sortedcodes + n * K, first, last, K); + + // select childA + int childa = split == first ? (K - 1) + split : split; + + // select childB + int childb = split + 1 == last ? 
(K - 1) + split + 1 : split + 1; + + // record parent-child relationships + nodechildren[n * (K + K - 1) + k] = make_int2(childa, childb); + nodeparent[n * (K + K - 1) + childa] = k; + nodeparent[n * (K + K - 1) + childb] = k; + } + } +} + +template +__global__ void compute_aabb_kernel( + int N, int K, + typename PrimTransfT::Data data, + int * sortedobjid, + int2 * nodechildren, + int * nodeparent, + float3 * nodeaabb, + int * atom) { + const int count = N * K; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) { + const int k = index % K; + const int n = index / K; + + // compute BBOX for leaf + int kk = sortedobjid[n * K + k]; + + float3 pmin; + float3 pmax; + data.compute_aabb(n, kk, pmin, pmax); + + nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 0] = pmin; + nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 1] = pmax; + + int node = nodeparent[n * (K + K - 1) + ((K - 1) + k)]; + + while (node != -1 && atomicCAS(&atom[n * (K - 1) + node], 0, 1) == 1) { + int2 children = nodechildren[n * (K + K - 1) + node]; + float3 laabbmin = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 0]; + float3 laabbmax = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 1]; + float3 raabbmin = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 0]; + float3 raabbmax = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 1]; + + float3 aabbmin = fminf(laabbmin, raabbmin); + float3 aabbmax = fmaxf(laabbmax, raabbmax); + + nodeaabb[n * (K + K - 1) * 2 + node * 2 + 0] = aabbmin; + nodeaabb[n * (K + K - 1) * 2 + node * 2 + 1] = aabbmax; + + node = nodeparent[n * (K + K - 1) + node]; + + __threadfence(); + } + } +} + +void compute_morton_cuda( + int N, int K, + float * primpos, + int * code, + int algorithm, + cudaStream_t stream) { + int count = N * K; + int blocksize = 512; + int gridsize = (count + blocksize - 1) / blocksize; + + std::shared_ptr primtransf_data; + primtransf_data = std::make_shared(PrimTransfSRT::Data{ + PrimTransfDataBase{}, + K, (float3*)primpos, nullptr, + K * 3, nullptr, nullptr, + K, nullptr, nullptr}); + + std::map, int*)>> dispatcher = { + { 0, make_cudacall(compute_morton_kernel) } + }; + + auto iter = dispatcher.find(min(0, algorithm)); + if (iter != dispatcher.end()) { + (iter->second)( + dim3(gridsize), dim3(blocksize), stream, + N, K, + primtransf_data, + code); + } +} + +void build_tree_cuda( + int N, int K, + int * sortedcode, + int * nodechildren, + int * nodeparent, + cudaStream_t stream) { + int count = N * (K + K - 1); + int nthreads = 512; + int nblocks = (count + nthreads - 1) / nthreads; + build_tree_kernel<<>>( + N, K, + sortedcode, + reinterpret_cast(nodechildren), + nodeparent); +} + +void compute_aabb_cuda( + int N, int K, + float * primpos, + float * primrot, + float * primscale, + int * sortedobjid, + int * nodechildren, + int * nodeparent, + float * nodeaabb, + int algorithm, + cudaStream_t stream) { + int * atom; + cudaMalloc(&atom, N * (K - 1) * 4); + cudaMemset(atom, 0, N * (K - 1) * 4); + + int count = N * K; + int blocksize = 512; + int gridsize = (count + blocksize - 1) / blocksize; + + std::shared_ptr primtransf_data; + primtransf_data = std::make_shared(PrimTransfSRT::Data{ + PrimTransfDataBase{}, + K, (float3*)primpos, nullptr, + K * 3, (float3*)primrot, nullptr, + K, (float3*)primscale, nullptr}); + + std::map, int*, int2*, int*, float3*, int*)>> dispatcher = { + { 0, make_cudacall(compute_aabb_kernel) } + }; + + auto iter = dispatcher.find(min(0, algorithm)); + if (iter != dispatcher.end()) { + 
(iter->second)( + dim3(gridsize), dim3(blocksize), stream, + N, K, + primtransf_data, + sortedobjid, + reinterpret_cast(nodechildren), + nodeparent, + reinterpret_cast(nodeaabb), + atom); + } + + cudaFree(atom); +} diff --git a/dva/mvp/extensions/mvpraymarch/cudadispatch.h b/dva/mvp/extensions/mvpraymarch/cudadispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ed6e4f595268499bf3b4f4c58107231c141f9534 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/cudadispatch.h @@ -0,0 +1,104 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef cudadispatch_h_ +#define cudadispatch_h_ + +#include +#include +#include + +template +struct get_base { + typedef T type; +}; + +template +struct get_base::value>::type> { + typedef std::shared_ptr type; +}; + +template struct is_shared_ptr : std::false_type {}; +template struct is_shared_ptr> : std::true_type {}; + +template +auto convert_shptr_impl2(std::shared_ptr t) { + return *static_cast(t.get()); +} + +template +auto convert_shptr_impl(T&& t, std::false_type) { + return convert_shptr_impl2(t); +} + +template +auto convert_shptr_impl(T&& t, std::true_type) { + return std::forward(t); +} + +template +auto convert_shptr(T&& t) { + return convert_shptr_impl(std::forward(t), std::is_same{}); +} + +template +struct cudacall { + struct functbase { + virtual ~functbase() {} + virtual void call(dim3, dim3, cudaStream_t, ArgsIn...) const = 0; + }; + + template + struct funct : public functbase { + std::function fn; + funct(void(*fn_)(ArgsOut...)) : fn(fn_) { } + void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsIn... args) const { + void (*const*kfunc)(ArgsOut...) = fn.template target(); + (*kfunc)<<>>( + std::forward(convert_shptr(std::forward(args)))...); + } + }; + + std::shared_ptr fn; + + template + cudacall(void(*fn_)(ArgsOut...)) : fn(std::make_shared>(fn_)) { } + + template + void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsTmp&&... args) const { + fn->call(gridsize, blocksize, stream, std::forward(args)...); + } +}; + +template +struct binder { + F f; T t; + template + auto operator()(Args&&... args) const + -> decltype(f(t, std::forward(args)...)) { + return f(t, std::forward(args)...); + } +}; + +template +binder::type + , typename std::decay::type> BindFirst(F&& f, T&& t) { + return { std::forward(f), std::forward(t) }; +} + +template +auto make_cudacall_(void(*fn)(ArgsOut...)) { + return BindFirst( + std::mem_fn(&cudacall::type...>::template call::type...>), + cudacall::type...>(fn)); +} + +template +std::function::type...)> make_cudacall(void(*fn)(ArgsOut...)) { + return std::function::type...)>(make_cudacall_(fn)); +} + +#endif diff --git a/dva/mvp/extensions/mvpraymarch/helper_math.h b/dva/mvp/extensions/mvpraymarch/helper_math.h new file mode 100644 index 0000000000000000000000000000000000000000..c9c07c3e74bbd1f469740f95c45a2eae49322e99 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/helper_math.h @@ -0,0 +1,1453 @@ +/** + * Copyright 1993-2013 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. 
+ * + */ + +/* + * This file implements common mathematical operations on vector types + * (float3, float4 etc.) since these are not provided as standard by CUDA. + * + * The syntax is modeled on the Cg standard library. + * + * This is part of the Helper library includes + * + * Thanks to Linh Hah for additions and fixes. + */ + +#ifndef HELPER_MATH_H +#define HELPER_MATH_H + +#include "cuda_runtime.h" + +typedef unsigned int uint; +typedef unsigned short ushort; + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +#ifndef __CUDACC__ +#include + +//////////////////////////////////////////////////////////////////////////////// +// host implementations of CUDA functions +//////////////////////////////////////////////////////////////////////////////// + +inline float fminf(float a, float b) +{ + return a < b ? a : b; +} + +inline float fmaxf(float a, float b) +{ + return a > b ? a : b; +} + +inline int max(int a, int b) +{ + return a > b ? a : b; +} + +inline int min(int a, int b) +{ + return a < b ? a : b; +} + +inline float rsqrtf(float x) +{ + return 1.0f / sqrtf(x); +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +// constructors +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 make_float2(float s) +{ + return make_float2(s, s); +} +inline __host__ __device__ float2 make_float2(float3 a) +{ + return make_float2(a.x, a.y); +} +inline __host__ __device__ float2 make_float2(int2 a) +{ + return make_float2(float(a.x), float(a.y)); +} +inline __host__ __device__ float2 make_float2(uint2 a) +{ + return make_float2(float(a.x), float(a.y)); +} + +inline __host__ __device__ int2 make_int2(int s) +{ + return make_int2(s, s); +} +inline __host__ __device__ int2 make_int2(int3 a) +{ + return make_int2(a.x, a.y); +} +inline __host__ __device__ int2 make_int2(uint2 a) +{ + return make_int2(int(a.x), int(a.y)); +} +inline __host__ __device__ int2 make_int2(float2 a) +{ + return make_int2(int(a.x), int(a.y)); +} + +inline __host__ __device__ uint2 make_uint2(uint s) +{ + return make_uint2(s, s); +} +inline __host__ __device__ uint2 make_uint2(uint3 a) +{ + return make_uint2(a.x, a.y); +} +inline __host__ __device__ uint2 make_uint2(int2 a) +{ + return make_uint2(uint(a.x), uint(a.y)); +} + +inline __host__ __device__ float3 make_float3(float s) +{ + return make_float3(s, s, s); +} +inline __host__ __device__ float3 make_float3(float2 a) +{ + return make_float3(a.x, a.y, 0.0f); +} +inline __host__ __device__ float3 make_float3(float2 a, float s) +{ + return make_float3(a.x, a.y, s); +} +inline __host__ __device__ float3 make_float3(float4 a) +{ + return make_float3(a.x, a.y, a.z); +} +inline __host__ __device__ float3 make_float3(int3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} +inline __host__ __device__ float3 make_float3(uint3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} + +inline __host__ __device__ int3 make_int3(int s) +{ + return make_int3(s, s, s); +} +inline __host__ __device__ int3 make_int3(int2 a) +{ + return make_int3(a.x, a.y, 0); +} +inline __host__ __device__ int3 make_int3(int2 a, int s) +{ + return make_int3(a.x, a.y, s); +} +inline __host__ __device__ int3 make_int3(uint3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} +inline __host__ __device__ int3 make_int3(float3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} + +inline __host__ __device__ uint3 make_uint3(uint s) +{ + return make_uint3(s, s, s); 
+} +inline __host__ __device__ uint3 make_uint3(uint2 a) +{ + return make_uint3(a.x, a.y, 0); +} +inline __host__ __device__ uint3 make_uint3(uint2 a, uint s) +{ + return make_uint3(a.x, a.y, s); +} +inline __host__ __device__ uint3 make_uint3(uint4 a) +{ + return make_uint3(a.x, a.y, a.z); +} +inline __host__ __device__ uint3 make_uint3(int3 a) +{ + return make_uint3(uint(a.x), uint(a.y), uint(a.z)); +} + +inline __host__ __device__ float4 make_float4(float s) +{ + return make_float4(s, s, s, s); +} +inline __host__ __device__ float4 make_float4(float3 a) +{ + return make_float4(a.x, a.y, a.z, 0.0f); +} +inline __host__ __device__ float4 make_float4(float3 a, float w) +{ + return make_float4(a.x, a.y, a.z, w); +} +inline __host__ __device__ float4 make_float4(int4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} +inline __host__ __device__ float4 make_float4(uint4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} + +inline __host__ __device__ int4 make_int4(int s) +{ + return make_int4(s, s, s, s); +} +inline __host__ __device__ int4 make_int4(int3 a) +{ + return make_int4(a.x, a.y, a.z, 0); +} +inline __host__ __device__ int4 make_int4(int3 a, int w) +{ + return make_int4(a.x, a.y, a.z, w); +} +inline __host__ __device__ int4 make_int4(uint4 a) +{ + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); +} +inline __host__ __device__ int4 make_int4(float4 a) +{ + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); +} + + +inline __host__ __device__ uint4 make_uint4(uint s) +{ + return make_uint4(s, s, s, s); +} +inline __host__ __device__ uint4 make_uint4(uint3 a) +{ + return make_uint4(a.x, a.y, a.z, 0); +} +inline __host__ __device__ uint4 make_uint4(uint3 a, uint w) +{ + return make_uint4(a.x, a.y, a.z, w); +} +inline __host__ __device__ uint4 make_uint4(int4 a) +{ + return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// negate +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator-(float2 &a) +{ + return make_float2(-a.x, -a.y); +} +inline __host__ __device__ int2 operator-(int2 &a) +{ + return make_int2(-a.x, -a.y); +} +inline __host__ __device__ float3 operator-(float3 &a) +{ + return make_float3(-a.x, -a.y, -a.z); +} +inline __host__ __device__ int3 operator-(int3 &a) +{ + return make_int3(-a.x, -a.y, -a.z); +} +inline __host__ __device__ float4 operator-(float4 &a) +{ + return make_float4(-a.x, -a.y, -a.z, -a.w); +} +inline __host__ __device__ int4 operator-(int4 &a) +{ + return make_int4(-a.x, -a.y, -a.z, -a.w); +} + +//////////////////////////////////////////////////////////////////////////////// +// addition +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator+(float2 a, float2 b) +{ + return make_float2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(float2 &a, float2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ float2 operator+(float2 a, float b) +{ + return make_float2(a.x + b, a.y + b); +} +inline __host__ __device__ float2 operator+(float b, float2 a) +{ + return make_float2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(float2 &a, float b) +{ + a.x += b; + a.y += b; +} + +inline __host__ __device__ int2 operator+(int2 a, int2 b) +{ + return make_int2(a.x + b.x, a.y + b.y); +} +inline __host__ 
__device__ void operator+=(int2 &a, int2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ int2 operator+(int2 a, int b) +{ + return make_int2(a.x + b, a.y + b); +} +inline __host__ __device__ int2 operator+(int b, int2 a) +{ + return make_int2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(int2 &a, int b) +{ + a.x += b; + a.y += b; +} + +inline __host__ __device__ uint2 operator+(uint2 a, uint2 b) +{ + return make_uint2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(uint2 &a, uint2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ uint2 operator+(uint2 a, uint b) +{ + return make_uint2(a.x + b, a.y + b); +} +inline __host__ __device__ uint2 operator+(uint b, uint2 a) +{ + return make_uint2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(uint2 &a, uint b) +{ + a.x += b; + a.y += b; +} + + +inline __host__ __device__ float3 operator+(float3 a, float3 b) +{ + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(float3 &a, float3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ float3 operator+(float3 a, float b) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(float3 &a, float b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ int3 operator+(int3 a, int3 b) +{ + return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(int3 &a, int3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ int3 operator+(int3 a, int b) +{ + return make_int3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(int3 &a, int b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ uint3 operator+(uint3 a, uint3 b) +{ + return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(uint3 &a, uint3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ uint3 operator+(uint3 a, uint b) +{ + return make_uint3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(uint3 &a, uint b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ int3 operator+(int b, int3 a) +{ + return make_int3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ uint3 operator+(uint b, uint3 a) +{ + return make_uint3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ float3 operator+(float b, float3 a) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} + +inline __host__ __device__ float4 operator+(float4 a, float4 b) +{ + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(float4 &a, float4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ float4 operator+(float4 a, float b) +{ + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ float4 operator+(float b, float4 a) +{ + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(float4 &a, float b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +inline __host__ __device__ int4 operator+(int4 a, int4 b) +{ + return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(int4 &a, int4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ int4 operator+(int4 a, int b) +{ + 
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ int4 operator+(int b, int4 a) +{ + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(int4 &a, int b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +inline __host__ __device__ uint4 operator+(uint4 a, uint4 b) +{ + return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(uint4 &a, uint4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ uint4 operator+(uint4 a, uint b) +{ + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ uint4 operator+(uint b, uint4 a) +{ + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(uint4 &a, uint b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +//////////////////////////////////////////////////////////////////////////////// +// subtract +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator-(float2 a, float2 b) +{ + return make_float2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(float2 &a, float2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ float2 operator-(float2 a, float b) +{ + return make_float2(a.x - b, a.y - b); +} +inline __host__ __device__ float2 operator-(float b, float2 a) +{ + return make_float2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(float2 &a, float b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ int2 operator-(int2 a, int2 b) +{ + return make_int2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(int2 &a, int2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ int2 operator-(int2 a, int b) +{ + return make_int2(a.x - b, a.y - b); +} +inline __host__ __device__ int2 operator-(int b, int2 a) +{ + return make_int2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(int2 &a, int b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ uint2 operator-(uint2 a, uint2 b) +{ + return make_uint2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(uint2 &a, uint2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ uint2 operator-(uint2 a, uint b) +{ + return make_uint2(a.x - b, a.y - b); +} +inline __host__ __device__ uint2 operator-(uint b, uint2 a) +{ + return make_uint2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(uint2 &a, uint b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ float3 operator-(float3 a, float3 b) +{ + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(float3 &a, float3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ float3 operator-(float3 a, float b) +{ + return make_float3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ float3 operator-(float b, float3 a) +{ + return make_float3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(float3 &a, float b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ int3 operator-(int3 a, int3 b) +{ + return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(int3 &a, int3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ int3 operator-(int3 a, int b) +{ + return make_int3(a.x - b, a.y - b, a.z - 
b); +} +inline __host__ __device__ int3 operator-(int b, int3 a) +{ + return make_int3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(int3 &a, int b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ uint3 operator-(uint3 a, uint3 b) +{ + return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(uint3 &a, uint3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ uint3 operator-(uint3 a, uint b) +{ + return make_uint3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ uint3 operator-(uint b, uint3 a) +{ + return make_uint3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(uint3 &a, uint b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ float4 operator-(float4 a, float4 b) +{ + return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(float4 &a, float4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ float4 operator-(float4 a, float b) +{ + return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ void operator-=(float4 &a, float b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +inline __host__ __device__ int4 operator-(int4 a, int4 b) +{ + return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(int4 &a, int4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ int4 operator-(int4 a, int b) +{ + return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ int4 operator-(int b, int4 a) +{ + return make_int4(b - a.x, b - a.y, b - a.z, b - a.w); +} +inline __host__ __device__ void operator-=(int4 &a, int b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +inline __host__ __device__ uint4 operator-(uint4 a, uint4 b) +{ + return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(uint4 &a, uint4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ uint4 operator-(uint4 a, uint b) +{ + return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ uint4 operator-(uint b, uint4 a) +{ + return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w); +} +inline __host__ __device__ void operator-=(uint4 &a, uint b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +//////////////////////////////////////////////////////////////////////////////// +// multiply +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator*(float2 a, float2 b) +{ + return make_float2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(float2 &a, float2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ float2 operator*(float2 a, float b) +{ + return make_float2(a.x * b, a.y * b); +} +inline __host__ __device__ float2 operator*(float b, float2 a) +{ + return make_float2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(float2 &a, float b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ int2 operator*(int2 a, int2 b) +{ + return make_int2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(int2 &a, int2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ int2 operator*(int2 a, int b) +{ + return make_int2(a.x * b, a.y * b); +} 
+inline __host__ __device__ int2 operator*(int b, int2 a) +{ + return make_int2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(int2 &a, int b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ uint2 operator*(uint2 a, uint2 b) +{ + return make_uint2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(uint2 &a, uint2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ uint2 operator*(uint2 a, uint b) +{ + return make_uint2(a.x * b, a.y * b); +} +inline __host__ __device__ uint2 operator*(uint b, uint2 a) +{ + return make_uint2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(uint2 &a, uint b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ float3 operator*(float3 a, float3 b) +{ + return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(float3 &a, float3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ float3 operator*(float3 a, float b) +{ + return make_float3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ float3 operator*(float b, float3 a) +{ + return make_float3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(float3 &a, float b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ int3 operator*(int3 a, int3 b) +{ + return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(int3 &a, int3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ int3 operator*(int3 a, int b) +{ + return make_int3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ int3 operator*(int b, int3 a) +{ + return make_int3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(int3 &a, int b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ uint3 operator*(uint3 a, uint3 b) +{ + return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(uint3 &a, uint3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ uint3 operator*(uint3 a, uint b) +{ + return make_uint3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ uint3 operator*(uint b, uint3 a) +{ + return make_uint3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(uint3 &a, uint b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ float4 operator*(float4 a, float4 b) +{ + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(float4 &a, float4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ float4 operator*(float4 a, float b) +{ + return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ float4 operator*(float b, float4 a) +{ + return make_float4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(float4 &a, float b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +inline __host__ __device__ int4 operator*(int4 a, int4 b) +{ + return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(int4 &a, int4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ int4 operator*(int4 a, int b) +{ + return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ int4 operator*(int b, int4 a) +{ + return make_int4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline 
__host__ __device__ void operator*=(int4 &a, int b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +inline __host__ __device__ uint4 operator*(uint4 a, uint4 b) +{ + return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(uint4 &a, uint4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ uint4 operator*(uint4 a, uint b) +{ + return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ uint4 operator*(uint b, uint4 a) +{ + return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(uint4 &a, uint b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +//////////////////////////////////////////////////////////////////////////////// +// divide +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator/(float2 a, float2 b) +{ + return make_float2(a.x / b.x, a.y / b.y); +} +inline __host__ __device__ void operator/=(float2 &a, float2 b) +{ + a.x /= b.x; + a.y /= b.y; +} +inline __host__ __device__ float2 operator/(float2 a, float b) +{ + return make_float2(a.x / b, a.y / b); +} +inline __host__ __device__ void operator/=(float2 &a, float b) +{ + a.x /= b; + a.y /= b; +} +inline __host__ __device__ float2 operator/(float b, float2 a) +{ + return make_float2(b / a.x, b / a.y); +} + +inline __host__ __device__ float3 operator/(float3 a, float3 b) +{ + return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ void operator/=(float3 &a, float3 b) +{ + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; +} +inline __host__ __device__ float3 operator/(float3 a, float b) +{ + return make_float3(a.x / b, a.y / b, a.z / b); +} +inline __host__ __device__ void operator/=(float3 &a, float b) +{ + a.x /= b; + a.y /= b; + a.z /= b; +} +inline __host__ __device__ float3 operator/(float b, float3 a) +{ + return make_float3(b / a.x, b / a.y, b / a.z); +} + +inline __host__ __device__ float4 operator/(float4 a, float4 b) +{ + return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); +} +inline __host__ __device__ void operator/=(float4 &a, float4 b) +{ + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; + a.w /= b.w; +} +inline __host__ __device__ float4 operator/(float4 a, float b) +{ + return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); +} +inline __host__ __device__ void operator/=(float4 &a, float b) +{ + a.x /= b; + a.y /= b; + a.z /= b; + a.w /= b; +} +inline __host__ __device__ float4 operator/(float b, float4 a) +{ + return make_float4(b / a.x, b / a.y, b / a.z, b / a.w); +} + +//////////////////////////////////////////////////////////////////////////////// +// min +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fminf(float2 a, float2 b) +{ + return make_float2(fminf(a.x,b.x), fminf(a.y,b.y)); +} +inline __host__ __device__ float3 fminf(float3 a, float3 b) +{ + return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); +} +inline __host__ __device__ float4 fminf(float4 a, float4 b) +{ + return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w)); +} + +inline __host__ __device__ int2 min(int2 a, int2 b) +{ + return make_int2(min(a.x,b.x), min(a.y,b.y)); +} +inline __host__ __device__ int3 min(int3 a, int3 b) +{ + return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} +inline __host__ __device__ int4 min(int4 a, int4 b) +{ + return 
make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w)); +} + +inline __host__ __device__ uint2 min(uint2 a, uint2 b) +{ + return make_uint2(min(a.x,b.x), min(a.y,b.y)); +} +inline __host__ __device__ uint3 min(uint3 a, uint3 b) +{ + return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} +inline __host__ __device__ uint4 min(uint4 a, uint4 b) +{ + return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// max +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fmaxf(float2 a, float2 b) +{ + return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y)); +} +inline __host__ __device__ float3 fmaxf(float3 a, float3 b) +{ + return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); +} +inline __host__ __device__ float4 fmaxf(float4 a, float4 b) +{ + return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w)); +} + +inline __host__ __device__ int2 max(int2 a, int2 b) +{ + return make_int2(max(a.x,b.x), max(a.y,b.y)); +} +inline __host__ __device__ int3 max(int3 a, int3 b) +{ + return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} +inline __host__ __device__ int4 max(int4 a, int4 b) +{ + return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w)); +} + +inline __host__ __device__ uint2 max(uint2 a, uint2 b) +{ + return make_uint2(max(a.x,b.x), max(a.y,b.y)); +} +inline __host__ __device__ uint3 max(uint3 a, uint3 b) +{ + return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} +inline __host__ __device__ uint4 max(uint4 a, uint4 b) +{ + return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// lerp +// - linear interpolation between a and b, based on value t in [0, 1] range +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float lerp(float a, float b, float t) +{ + return a + t*(b-a); +} +inline __device__ __host__ float2 lerp(float2 a, float2 b, float t) +{ + return a + t*(b-a); +} +inline __device__ __host__ float3 lerp(float3 a, float3 b, float t) +{ + return a + t*(b-a); +} +inline __device__ __host__ float4 lerp(float4 a, float4 b, float t) +{ + return a + t*(b-a); +} + +//////////////////////////////////////////////////////////////////////////////// +// clamp +// - clamp the value v to be in the range [a, b] +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float clamp(float f, float a, float b) +{ + return fmaxf(a, fminf(f, b)); +} +inline __device__ __host__ int clamp(int f, int a, int b) +{ + return max(a, min(f, b)); +} +inline __device__ __host__ uint clamp(uint f, uint a, uint b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ float2 clamp(float2 v, float a, float b) +{ + return make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b) +{ + return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ float3 clamp(float3 v, float a, float b) +{ + return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) +{ + return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ 
__host__ float4 clamp(float4 v, float a, float b) +{ + return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b) +{ + return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +inline __device__ __host__ int2 clamp(int2 v, int a, int b) +{ + return make_int2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b) +{ + return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ int3 clamp(int3 v, int a, int b) +{ + return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b) +{ + return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ int4 clamp(int4 v, int a, int b) +{ + return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b) +{ + return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b) +{ + return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b) +{ + return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b) +{ + return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b) +{ + return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b) +{ + return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b) +{ + return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// dot product +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float dot(float2 a, float2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ float dot(float3 a, float3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ float dot(float4 a, float4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +inline __host__ __device__ int dot(int2 a, int2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ int dot(int3 a, int3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ int dot(int4 a, int4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +inline __host__ __device__ uint dot(uint2 a, uint2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ uint dot(uint3 a, uint3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ uint dot(uint4 a, uint4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +//////////////////////////////////////////////////////////////////////////////// +// length +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float length(float2 v) +{ + return sqrtf(dot(v, v)); +} +inline 
__host__ __device__ float length(float3 v) +{ + return sqrtf(dot(v, v)); +} +inline __host__ __device__ float length(float4 v) +{ + return sqrtf(dot(v, v)); +} + +//////////////////////////////////////////////////////////////////////////////// +// normalize +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 normalize(float2 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} +inline __host__ __device__ float3 normalize(float3 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} +inline __host__ __device__ float4 normalize(float4 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +//////////////////////////////////////////////////////////////////////////////// +// floor +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 floorf(float2 v) +{ + return make_float2(floorf(v.x), floorf(v.y)); +} +inline __host__ __device__ float3 floorf(float3 v) +{ + return make_float3(floorf(v.x), floorf(v.y), floorf(v.z)); +} +inline __host__ __device__ float4 floorf(float4 v) +{ + return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// frac - returns the fractional portion of a scalar or each vector component +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float fracf(float v) +{ + return v - floorf(v); +} +inline __host__ __device__ float2 fracf(float2 v) +{ + return make_float2(fracf(v.x), fracf(v.y)); +} +inline __host__ __device__ float3 fracf(float3 v) +{ + return make_float3(fracf(v.x), fracf(v.y), fracf(v.z)); +} +inline __host__ __device__ float4 fracf(float4 v) +{ + return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// fmod +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fmodf(float2 a, float2 b) +{ + return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y)); +} +inline __host__ __device__ float3 fmodf(float3 a, float3 b) +{ + return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z)); +} +inline __host__ __device__ float4 fmodf(float4 a, float4 b) +{ + return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// absolute value +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fabs(float2 v) +{ + return make_float2(fabs(v.x), fabs(v.y)); +} +inline __host__ __device__ float3 fabs(float3 v) +{ + return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); +} +inline __host__ __device__ float4 fabs(float4 v) +{ + return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); +} + +inline __host__ __device__ int2 abs(int2 v) +{ + return make_int2(abs(v.x), abs(v.y)); +} +inline __host__ __device__ int3 abs(int3 v) +{ + return make_int3(abs(v.x), abs(v.y), abs(v.z)); +} +inline __host__ __device__ int4 abs(int4 v) +{ + return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// reflect +// - returns reflection of incident ray I around surface normal N +// - N should be normalized, reflected vector's length is equal to 
length of I +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float3 reflect(float3 i, float3 n) +{ + return i - 2.0f * n * dot(n,i); +} + +//////////////////////////////////////////////////////////////////////////////// +// cross product +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float3 cross(float3 a, float3 b) +{ + return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); +} + +//////////////////////////////////////////////////////////////////////////////// +// smoothstep +// - returns 0 if x < a +// - returns 1 if x > b +// - otherwise returns smooth interpolation between 0 and 1 based on x +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float smoothstep(float a, float b, float x) +{ + float y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(3.0f - (2.0f*y))); +} +inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x) +{ + float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y))); +} +inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) +{ + float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y))); +} +inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x) +{ + float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y))); +} + +#endif diff --git a/dva/mvp/extensions/mvpraymarch/makefile b/dva/mvp/extensions/mvpraymarch/makefile new file mode 100644 index 0000000000000000000000000000000000000000..4a1f97a7a9c9320562641ad94b7ada28ef5c2777 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/makefile @@ -0,0 +1,2 @@ +all: + python setup.py build_ext --inplace diff --git a/dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp b/dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..909e85ac8824360acc08e2c64e9a78262e710bb5 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp @@ -0,0 +1,405 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
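
[Editor's note] The makefile added in this directory builds the extension in place with "python setup.py build_ext --inplace", but the setup.py itself is not included in this diff. The following is a minimal, illustrative sketch of such a build script, assuming the extension module is named mvpraymarchlib (the name imported by mvpraymarch.py later in this diff) and is compiled from mvpraymarch.cpp together with the .cu sources in this directory; the actual script may differ.

    # setup.py (illustrative sketch only; the real build script is not shown in this diff)
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name="mvpraymarchlib",
        ext_modules=[
            CUDAExtension(
                name="mvpraymarchlib",
                # bvh.cu and mvpraymarch.cpp are added in this diff; the raymarching
                # kernels declared in mvpraymarch.cpp presumably live in additional
                # .cu files in this directory that are not shown in this excerpt.
                sources=["mvpraymarch.cpp", "bvh.cu"],
            ),
        ],
        # BuildExtension drives nvcc and the host compiler and links against the
        # active PyTorch installation, producing mvpraymarchlib*.so in place.
        cmdclass={"build_ext": BuildExtension},
    )
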
+ +#include +#include + +#include + +void compute_morton_cuda( + int N, int K, + float * primpos, + int * code, + int algorithm, + cudaStream_t stream); + +void build_tree_cuda( + int N, int K, + int * sortedcode, + int * nodechildren, + int * nodeparent, + cudaStream_t stream); + +void compute_aabb_cuda( + int N, int K, + float * primpos, + float * primrot, + float * primscale, + int * sortedobjid, + int * nodechildren, + int * nodeparent, + float * nodeaabb, + int algorithm, + cudaStream_t stream); + +void raymarch_forward_cuda( + int N, int H, int W, int K, + float * rayposim, + float * raydirim, + float stepsize, + float * tminmaxim, + + int * sortedobjid, + int * nodechildren, + float * nodeaabb, + + float * primpos, + float * primrot, + float * primscale, + + int TD, int TH, int TW, + float * tplate, + int WD, int WH, int WW, + float * warp, + + float * rayrgbaim, + float * raysatim, + int * raytermim, + + int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes, + bool chlast, float fadescale, float fadeexp, int accum, float termthresh, + int griddim, int blocksizex, int blocksizey, + cudaStream_t stream); + +void raymarch_backward_cuda( + int N, int H, int W, int K, + float * rayposim, + float * raydirim, + float stepsize, + float * tminmaxim, + + int * sortedobjid, + int * nodechildren, + float * nodeaabb, + + float * primpos, + float * grad_primpos, + float * primrot, + float * grad_primrot, + float * primscale, + float * grad_primscale, + + int TD, int TH, int TW, + float * tplate, + float * grad_tplate, + int WD, int WH, int WW, + float * warp, + float * grad_warp, + + float * rayrgbaim, + float * grad_rayrgba, + float * raysatim, + int * raytermim, + + int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes, + bool chlast, float fadescale, float fadeexp, int accum, float termthresh, + int griddim, int blocksizex, int blocksizey, + cudaStream_t stream); + +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x)) + +std::vector compute_morton( + torch::Tensor primpos, + torch::Tensor code, + int algorithm) { + CHECK_INPUT(primpos); + CHECK_INPUT(code); + + int N = primpos.size(0); + int K = primpos.size(1); + + compute_morton_cuda( + N, K, + reinterpret_cast(primpos.data_ptr()), + reinterpret_cast(code.data_ptr()), + algorithm, + 0); + + return {}; +} + +std::vector build_tree( + torch::Tensor sortedcode, + torch::Tensor nodechildren, + torch::Tensor nodeparent) { + CHECK_INPUT(sortedcode); + CHECK_INPUT(nodechildren); + CHECK_INPUT(nodeparent); + + int N = sortedcode.size(0); + int K = sortedcode.size(1); + + build_tree_cuda(N, K, + reinterpret_cast(sortedcode.data_ptr()), + reinterpret_cast(nodechildren.data_ptr()), + reinterpret_cast(nodeparent.data_ptr()), + 0); + + return {}; +} + +std::vector compute_aabb( + torch::Tensor primpos, + torch::optional primrot, + torch::optional primscale, + torch::Tensor sortedobjid, + torch::Tensor nodechildren, + torch::Tensor nodeparent, + torch::Tensor nodeaabb, + int algorithm) { + CHECK_INPUT(sortedobjid); + CHECK_INPUT(primpos); + if (primrot) { CHECK_INPUT(*primrot); } + if (primscale) { CHECK_INPUT(*primscale); } + CHECK_INPUT(nodechildren); + CHECK_INPUT(nodeparent); + CHECK_INPUT(nodeaabb); + + int N = primpos.size(0); + int K = primpos.size(1); + + compute_aabb_cuda(N, K, + reinterpret_cast(primpos.data_ptr()), + primrot ? 
reinterpret_cast(primrot->data_ptr()) : nullptr, + primscale ? reinterpret_cast(primscale->data_ptr()) : nullptr, + reinterpret_cast(sortedobjid.data_ptr()), + reinterpret_cast(nodechildren.data_ptr()), + reinterpret_cast(nodeparent.data_ptr()), + reinterpret_cast(nodeaabb.data_ptr()), + algorithm, + 0); + + return {}; +} + +std::vector raymarch_forward( + torch::Tensor rayposim, + torch::Tensor raydirim, + float stepsize, + torch::Tensor tminmaxim, + + torch::optional sortedobjid, + torch::optional nodechildren, + torch::optional nodeaabb, + + torch::Tensor primpos, + torch::optional primrot, + torch::optional primscale, + + torch::Tensor tplate, + torch::optional warp, + + torch::Tensor rayrgbaim, + torch::optional raysatim, + torch::optional raytermim, + + int algorithm=0, + bool sortboxes=true, + int maxhitboxes=512, + bool synchitboxes=false, + bool chlast=false, + float fadescale=8.f, + float fadeexp=8.f, + int accum=0, + float termthresh=0.f, + int griddim=3, + int blocksizex=8, + int blocksizey=16) { + CHECK_INPUT(rayposim); + CHECK_INPUT(raydirim); + CHECK_INPUT(tminmaxim); + if (sortedobjid) { CHECK_INPUT(*sortedobjid); } + if (nodechildren) { CHECK_INPUT(*nodechildren); } + if (nodeaabb) { CHECK_INPUT(*nodeaabb); } + CHECK_INPUT(tplate); + if (warp) { CHECK_INPUT(*warp); } + CHECK_INPUT(primpos); + if (primrot) { CHECK_INPUT(*primrot); } + if (primscale) { CHECK_INPUT(*primscale); } + CHECK_INPUT(rayrgbaim); + if (raysatim) { CHECK_INPUT(*raysatim); } + if (raytermim) { CHECK_INPUT(*raytermim); } + + int N = rayposim.size(0); + int H = rayposim.size(1); + int W = rayposim.size(2); + int K = primpos.size(1); + + int TD, TH, TW; + if (chlast) { + TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4); + } else { + TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5); + } + + int WD = 0, WH = 0, WW = 0; + if (warp) { + if (chlast) { + WD = warp->size(2); WH = warp->size(3); WW = warp->size(4); + } else { + WD = warp->size(3); WH = warp->size(4); WW = warp->size(5); + } + } + + raymarch_forward_cuda(N, H, W, K, + reinterpret_cast(rayposim.data_ptr()), + reinterpret_cast(raydirim.data_ptr()), + stepsize, + reinterpret_cast(tminmaxim.data_ptr()), + sortedobjid ? reinterpret_cast(sortedobjid->data_ptr()) : nullptr, + nodechildren ? reinterpret_cast(nodechildren->data_ptr()) : nullptr, + nodeaabb ? reinterpret_cast(nodeaabb->data_ptr()) : nullptr, + + // prim transforms + reinterpret_cast(primpos.data_ptr()), + primrot ? reinterpret_cast(primrot->data_ptr()) : nullptr, + primscale ? reinterpret_cast(primscale->data_ptr()) : nullptr, + + // prim sampler + TD, TH, TW, + reinterpret_cast(tplate.data_ptr()), + WD, WH, WW, + warp ? reinterpret_cast(warp->data_ptr()) : nullptr, + + // prim accumulator + reinterpret_cast(rayrgbaim.data_ptr()), + raysatim ? reinterpret_cast(raysatim->data_ptr()) : nullptr, + raytermim ? 
reinterpret_cast(raytermim->data_ptr()) : nullptr, + + // options + algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh, + griddim, blocksizex, blocksizey, + 0); + + return {}; +} + +std::vector raymarch_backward( + torch::Tensor rayposim, + torch::Tensor raydirim, + float stepsize, + torch::Tensor tminmaxim, + + torch::optional sortedobjid, + torch::optional nodechildren, + torch::optional nodeaabb, + + torch::Tensor primpos, + torch::Tensor grad_primpos, + torch::optional primrot, + torch::optional grad_primrot, + torch::optional primscale, + torch::optional grad_primscale, + + torch::Tensor tplate, + torch::Tensor grad_tplate, + torch::optional warp, + torch::optional grad_warp, + + torch::Tensor rayrgbaim, + torch::Tensor grad_rayrgba, + torch::optional raysatim, + torch::optional raytermim, + + int algorithm=0, + bool sortboxes=true, + int maxhitboxes=512, + bool synchitboxes=false, + bool chlast=false, + float fadescale=8.f, + float fadeexp=8.f, + int accum=0, + float termthresh=0.f, + int griddim=3, + int blocksizex=8, + int blocksizey=16) { + CHECK_INPUT(rayposim); + CHECK_INPUT(raydirim); + CHECK_INPUT(tminmaxim); + if (sortedobjid) { CHECK_INPUT(*sortedobjid); } + if (nodechildren) { CHECK_INPUT(*nodechildren); } + if (nodeaabb) { CHECK_INPUT(*nodeaabb); } + CHECK_INPUT(tplate); + if (warp) { CHECK_INPUT(*warp); } + CHECK_INPUT(primpos); + if (primrot) { CHECK_INPUT(*primrot); } + if (primscale) { CHECK_INPUT(*primscale); } + CHECK_INPUT(rayrgbaim); + if (raysatim) { CHECK_INPUT(*raysatim); } + if (raytermim) { CHECK_INPUT(*raytermim); } + CHECK_INPUT(grad_rayrgba); + CHECK_INPUT(grad_tplate); + if (grad_warp) { CHECK_INPUT(*grad_warp); } + CHECK_INPUT(grad_primpos); + if (grad_primrot) { CHECK_INPUT(*grad_primrot); } + if (grad_primscale) { CHECK_INPUT(*grad_primscale); } + + int N = rayposim.size(0); + int H = rayposim.size(1); + int W = rayposim.size(2); + int K = primpos.size(1); + + int TD, TH, TW; + if (chlast) { + TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4); + } else { + TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5); + } + + int WD = 0, WH = 0, WW = 0; + if (warp) { + if (chlast) { + WD = warp->size(2); WH = warp->size(3); WW = warp->size(4); + } else { + WD = warp->size(3); WH = warp->size(4); WW = warp->size(5); + } + } + + raymarch_backward_cuda(N, H, W, K, + reinterpret_cast(rayposim.data_ptr()), + reinterpret_cast(raydirim.data_ptr()), + stepsize, + reinterpret_cast(tminmaxim.data_ptr()), + sortedobjid ? reinterpret_cast(sortedobjid->data_ptr()) : nullptr, + nodechildren ? reinterpret_cast(nodechildren->data_ptr()) : nullptr, + nodeaabb ? reinterpret_cast(nodeaabb->data_ptr()) : nullptr, + + reinterpret_cast(primpos.data_ptr()), + reinterpret_cast(grad_primpos.data_ptr()), + primrot ? reinterpret_cast(primrot->data_ptr()) : nullptr, + grad_primrot ? reinterpret_cast(grad_primrot->data_ptr()) : nullptr, + primscale ? reinterpret_cast(primscale->data_ptr()) : nullptr, + grad_primscale ? reinterpret_cast(grad_primscale->data_ptr()) : nullptr, + + TD, TH, TW, + reinterpret_cast(tplate.data_ptr()), + reinterpret_cast(grad_tplate.data_ptr()), + WD, WH, WW, + warp ? reinterpret_cast(warp->data_ptr()) : nullptr, + grad_warp ? reinterpret_cast(grad_warp->data_ptr()) : nullptr, + + reinterpret_cast(rayrgbaim.data_ptr()), + reinterpret_cast(grad_rayrgba.data_ptr()), + raysatim ? reinterpret_cast(raysatim->data_ptr()) : nullptr, + raytermim ? 
reinterpret_cast(raytermim->data_ptr()) : nullptr, + + algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh, + griddim, blocksizex, blocksizey, + 0); + + return {}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compute_morton", &compute_morton, "compute morton codes (CUDA)"); + m.def("build_tree", &build_tree, "build BVH tree (CUDA)"); + m.def("compute_aabb", &compute_aabb, "compute AABB sizes (CUDA)"); + + m.def("raymarch_forward", &raymarch_forward, "raymarch forward (CUDA)"); + m.def("raymarch_backward", &raymarch_backward, "raymarch backward (CUDA)"); +} diff --git a/dva/mvp/extensions/mvpraymarch/mvpraymarch.py b/dva/mvp/extensions/mvpraymarch/mvpraymarch.py new file mode 100644 index 0000000000000000000000000000000000000000..da2e047d286736c7cd3d8c182862e5de1e8e986b --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/mvpraymarch.py @@ -0,0 +1,559 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +import torch.nn.functional as F + +try: + from . import mvpraymarchlib +except: + import mvpraymarchlib + +def build_accel(primtransfin, algo, fixedorder=False): + """build bvh structure given primitive centers and sizes + + Parameters: + ---------- + primtransfin : tuple[tensor, tensor, tensor] + primitive transform tensors + algo : int + raymarching algorithm + fixedorder : optional[str] + True means the bvh builder will not reorder primitives and will + use a trivial tree structure. Likely to be slow for arbitrary + configurations of primitives. 
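+
+    Returns:
+    -------
+    sortedobjid : int32 tensor, N x K
+        per-batch primitive indices, in Morton-code order (or the identity
+        order when fixedorder is True)
+    nodechildren : int32 tensor, N x (2K - 1) x 2
+        child indices of every BVH node; leaf children are encoded as
+        negative values
+    nodeaabb : float32 tensor, N x (2K - 1) x 2 x 3
+        min/max corners of the axis-aligned bounding box of every node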
+ + """ + primpos, primrot, primscale = primtransfin + + N = primpos.size(0) + K = primpos.size(1) + + dev = primpos.device + + # compute and sort morton codes + if fixedorder: + sortedobjid = (torch.arange(N*K, dtype=torch.int32, device=dev) % K).view(N, K) + else: + cmax = primpos.max(dim=1, keepdim=True)[0] + cmin = primpos.min(dim=1, keepdim=True)[0] + + centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8) + + mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev) + mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo) + sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1) + sortedobjid = sortedobjid_long.int() + + if fixedorder: + nodechildren = torch.cat([ + torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev), + torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev), 2, rounding_mode="floor")], + dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1) + nodeparent = ( + torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev), 2, rounding_mode="floor") + .view(1, -1).repeat(N, 1)) + else: + nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev) + nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev) + mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent) + + nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev) + mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren, nodeparent, nodeaabb, algo) + + return sortedobjid, nodechildren, nodeaabb + +class MVPRaymarch(Function): + """Custom Function for raymarching Mixture of Volumetric Primitives.""" + @staticmethod + def forward(self, raypos, raydir, stepsize, tminmax, + primpos, primrot, primscale, + template, warp, + rayterm, gradmode, options): + algo = options["algo"] + usebvh = options["usebvh"] + sortprims = options["sortprims"] + randomorder = options["randomorder"] + maxhitboxes = options["maxhitboxes"] + synchitboxes = options["synchitboxes"] + chlast = options["chlast"] + fadescale = options["fadescale"] + fadeexp = options["fadeexp"] + accum = options["accum"] + termthresh = options["termthresh"] + griddim = options["griddim"] + if isinstance(options["blocksize"], tuple): + blocksizex, blocksizey = options["blocksize"] + else: + blocksizex = options["blocksize"] + blocksizey = 1 + + assert raypos.is_contiguous() and raypos.size(3) == 3 + assert raydir.is_contiguous() and raydir.size(3) == 3 + assert tminmax.is_contiguous() and tminmax.size(3) == 2 + + assert primpos is None or primpos.is_contiguous() and primpos.size(2) == 3 + assert primrot is None or primrot.is_contiguous() and primrot.size(2) == 3 + assert primscale is None or primscale.is_contiguous() and primscale.size(2) == 3 + + if chlast: + assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4 + assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3) + else: + assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4 + assert warp is None or (warp.is_contiguous() and warp.size(2) == 3) + + primtransfin = (primpos, primrot, primscale) + + # Build bvh + if usebvh is not False: + # compute radius of primitives + sortedobjid, nodechildren, nodeaabb = build_accel(primtransfin, + algo, fixedorder=usebvh=="fixedorder") + assert sortedobjid.is_contiguous() + assert nodechildren.is_contiguous() + assert nodeaabb.is_contiguous() + + if randomorder: + sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))] + else: + _, sortedobjid, 
nodechildren, nodeaabb = None, None, None, None + + # march through boxes + N, H, W = raypos.size(0), raypos.size(1), raypos.size(2) + rayrgba = torch.empty((N, H, W, 4), device=raypos.device) + if gradmode: + raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device) + rayterm = None + else: + raysat = None + rayterm = None + + mvpraymarchlib.raymarch_forward( + raypos, raydir, stepsize, tminmax, + sortedobjid, nodechildren, nodeaabb, + *primtransfin, + template, warp, + rayrgba, raysat, rayterm, + algo, sortprims, maxhitboxes, synchitboxes, chlast, + fadescale, fadeexp, + accum, termthresh, + griddim, blocksizex, blocksizey) + + self.save_for_backward( + raypos, raydir, tminmax, + sortedobjid, nodechildren, nodeaabb, + primpos, primrot, primscale, + template, warp, + rayrgba, raysat, rayterm) + self.options = options + self.stepsize = stepsize + + return rayrgba + + @staticmethod + def backward(self, grad_rayrgba): + (raypos, raydir, tminmax, + sortedobjid, nodechildren, nodeaabb, + primpos, primrot, primscale, + template, warp, + rayrgba, raysat, rayterm) = self.saved_tensors + algo = self.options["algo"] + usebvh = self.options["usebvh"] + sortprims = self.options["sortprims"] + maxhitboxes = self.options["maxhitboxes"] + synchitboxes = self.options["synchitboxes"] + chlast = self.options["chlast"] + fadescale = self.options["fadescale"] + fadeexp = self.options["fadeexp"] + accum = self.options["accum"] + termthresh = self.options["termthresh"] + griddim = self.options["griddim"] + if isinstance(self.options["bwdblocksize"], tuple): + blocksizex, blocksizey = self.options["bwdblocksize"] + else: + blocksizex = self.options["bwdblocksize"] + blocksizey = 1 + + stepsize = self.stepsize + + grad_primpos = torch.zeros_like(primpos) + grad_primrot = torch.zeros_like(primrot) + grad_primscale = torch.zeros_like(primscale) + primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale) + + grad_template = torch.zeros_like(template) + grad_warp = torch.zeros_like(warp) if warp is not None else None + + mvpraymarchlib.raymarch_backward(raypos, raydir, stepsize, tminmax, + sortedobjid, nodechildren, nodeaabb, + + *primtransfin, + + template, grad_template, warp, grad_warp, + + rayrgba, grad_rayrgba.contiguous(), raysat, rayterm, + + algo, sortprims, maxhitboxes, synchitboxes, chlast, + fadescale, fadeexp, + accum, termthresh, + griddim, blocksizex, blocksizey) + + return (None, None, None, None, + grad_primpos, grad_primrot, grad_primscale, + grad_template, grad_warp, + None, None, None) + +def mvpraymarch(raypos, raydir, stepsize, tminmax, + primtransf, + template, warp, + rayterm=None, + algo=0, usebvh="fixedorder", + sortprims=False, randomorder=False, + maxhitboxes=512, synchitboxes=True, + chlast=True, fadescale=8., fadeexp=8., + accum=0, termthresh=0., + griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)): + """Main entry point for raymarching MVP. + + Parameters: + ---------- + raypos: N x H x W x 3 tensor of ray origins + raydir: N x H x W x 3 tensor of ray directions + stepsize: raymarching step size + tminmax: N x H x W x 2 tensor of raymarching min/max bounds + template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives + warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional) + primpos: N x K x 3 tensor of primitive centers + primrot: N x K x 3 x 3 tensor of primitive orientations + primscale: N x K x 3 tensor of primitive inverse dimension lengths + algo: algorithm for raymarching (valid values: 0, 1). 
algo=0 is the fastest. + Currently algo=0 has a limit of 512 primitives per ray, so problems can + occur if there are many more boxes. all sortprims=True options have + this limitation, but you can use (algo=1, sortprims=False, + usebvh="fixedorder") which works correctly and has no primitive number + limitation (but is slightly slower). + usebvh: True to use bvh, "fixedorder" for a simple BVH, False for no bvh + sortprims: True to sort overlapping primitives at a sample point. Must + be True for gradients to match the PyTorch gradients. Seems unstable + if False but also not a big performance bottleneck. + chlast: whether template is provided as channels last or not. True tends + to be faster. + fadescale: Opacity is faded at the borders of the primitives by the equation + exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of + the primitive. + fadeexp: Opacity is faded at the borders of the primitives by the equation + exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of + the primitive. + griddim: CUDA grid dimensionality. + blocksize: blocksize of CUDA kernels. Should be 2-element tuple if + griddim>1, or integer if griddim==1.""" + if isinstance(primtransf, tuple): + primpos, primrot, primscale = primtransf + else: + primpos, primrot, primscale = ( + primtransf[:, :, 0, :].contiguous(), + primtransf[:, :, 1:4, :].contiguous(), + primtransf[:, :, 4, :].contiguous()) + primtransfin = (primpos, primrot, primscale) + + out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax, + *primtransfin, + template, warp, + rayterm, torch.is_grad_enabled(), + {"algo": algo, "usebvh": usebvh, "sortprims": sortprims, "randomorder": randomorder, + "maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes, + "chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp, + "accum": accum, "termthresh": termthresh, + "griddim": griddim, "blocksize": blocksize, "bwdblocksize": bwdblocksize}) + return out + +class Rodrigues(nn.Module): + def __init__(self): + super(Rodrigues, self).__init__() + + def forward(self, rvec): + theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) + rvec = rvec / theta[:, None] + costh = torch.cos(theta) + sinth = torch.sin(theta) + return torch.stack(( + rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh, + rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth, + rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth, + + rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth, + rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh, + rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth, + + rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth, + rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth, + rvec[:, 2] ** 2 + (1. 
- rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3) + +def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False, + dowarp=False, chlast=False, fadescale=8., fadeexp=8., + accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)): + N = 2 + H = 65 + W = 65 + k3 = 4 + K = k3*k3*k3 + + M = 32 + + print("=================================================================") + print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, " + "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, " + "blocksize={}, bwdblocksize={}".format( + usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast, + fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize, + bwdblocksize)) + + # generate random inputs + torch.manual_seed(1112) + + coherent_rays = True + if not coherent_rays: + _raypos = torch.randn(N, H, W, 3).to("cuda") + _raydir = torch.randn(N, H, W, 3).to("cuda") + _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True)) + else: + focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)]) + princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)]) + pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) + pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1) + + raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :] + raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1) + raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True)) + + _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda") + _raydir = raydir.to("cuda") + _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True)) + + max_len = 6.0 + _stepsize = max_len / 15.386928 + _tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \ + torch.rand(N, H, W, 2, device="cuda") * 1. + + _template = torch.randn(N, K, 4, M, M, M, requires_grad=True) + _template.data[:, :, -1, :, :, :] -= 3.5 + _template = _template.contiguous().detach().clone() + _template.requires_grad = True + gridxyz = torch.stack(torch.meshgrid( + torch.linspace(-1., 1., M//2), + torch.linspace(-1., 1., M//2), + torch.linspace(-1., 1., M//2))[::-1], dim=0).contiguous() + _warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone() + _warp.requires_grad = True + _primpos = torch.randn(N, K, 3, requires_grad=True) + _primpos = torch.randn(N, K, 3, requires_grad=True) + + coherent_centers = True + if coherent_centers: + ns = k3 + #assert ns*ns*ns==K + grid3d = torch.stack(torch.meshgrid( + torch.linspace(-1., 1., ns), + torch.linspace(-1., 1., ns), + torch.linspace(-1., 1., K//(ns*ns)))[::-1], dim=0)[None] + _primpos = (( + grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) + + 0.1 * torch.randn(N, K, 3, requires_grad=True) + )).contiguous().detach().clone() + _primpos.requires_grad = True + scale_ws = 1. 
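+    # Sketch note (added comment, not in the original): the Rodrigues module
+    # defined above maps an axis-angle vector rvec to a rotation matrix via
+    #     R = cos(theta) * I + (1 - cos(theta)) * r r^T + sin(theta) * [r]_x,
+    # where theta = |rvec| (a small 1e-5 is added under the sqrt to avoid
+    # dividing by zero) and r = rvec / theta. The next lines draw random
+    # axis-angle vectors and convert them into per-primitive orientation
+    # matrices.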
+ _primrot = torch.randn(N, K, 3) + rodrigues = Rodrigues() + _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone() + _primrot.requires_grad = True + + _primscale = torch.randn(N, K, 3, requires_grad=True) + _primscale.data *= 0.0 + + if dowarp: + params = [_template, _warp, _primscale, _primrot, _primpos] + paramnames = ["template", "warp", "primscale", "primrot", "primpos"] + else: + params = [_template, _primscale, _primrot, _primpos] + paramnames = ["template", "primscale", "primrot", "primpos"] + + termthreshorig = termthresh + + ########################### run pytorch version ########################### + + raypos = _raypos + raydir = _raydir + stepsize = _stepsize + tminmax = _tminmax + + #template = F.softplus(_template.to("cuda") * 1.5) + template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5 + warp = _warp.to("cuda") + primpos = _primpos.to("cuda") * 0.3 + primrot = _primrot.to("cuda") + primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda")) + + # python raymarching implementation + rayrgba = torch.zeros((N, H, W, 4)).to("cuda") + raypos = raypos + raydir * tminmax[:, :, :, 0, None] + t = tminmax[:, :, :, 0] + + step = 0 + t0 = t.detach().clone() + raypos0 = raypos.detach().clone() + + torch.cuda.synchronize() + time0 = time.time() + + while (t < tminmax[:, :, :, 1]).any(): + valid2 = torch.ones_like(rayrgba[:, :, :, 3:4]) + + for k in range(K): + y0 = torch.bmm( + (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)), + primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :] + + fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True)) + + if dowarp: + y1 = F.grid_sample( + warp[:, k, :, :, :, :], + y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1) + else: + y1 = y0 + + sample = F.grid_sample( + template[:, k, :, :, :, :], + y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1) + + valid1 = ( + torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) * + torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True)) + + valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None] + + alpha0 = sample[:, :, :, 3:4] + + rgb = sample[:, :, :, 0:3] * valid * valid1 + alpha = alpha0 * fade * stepsize * valid * valid1 + + if accum == 0: + newalpha = rayrgba[:, :, :, 3:4] + alpha + contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1 + rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1) + else: + raise + + step += 1 + t = t0 + stepsize * step + raypos = raypos0 + raydir * stepsize * step + + print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item()) + + sample0 = rayrgba + + torch.cuda.synchronize() + time1 = time.time() + + sample0.backward(torch.ones_like(sample0)) + + torch.cuda.synchronize() + time2 = time.time() + + print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total")) + print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0)) + + grads0 = [p.grad.detach().clone() for p in params] + + for p in params: + p.grad.detach_() + p.grad.zero_() + + ############################## run cuda version ########################### + + raypos = _raypos + raydir = _raydir + stepsize = _stepsize + tminmax = _tminmax + + template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5 + warp = _warp.to("cuda") + if chlast: + 
template = template.permute(0, 1, 3, 4, 5, 2).contiguous() + warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous() + primpos = _primpos.to("cuda") * 0.3 + primrot = _primrot.to("cuda") + primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda")) + + niter = 1 + + tf, tb = 0., 0. + for i in range(niter): + for p in params: + try: + p.grad.detach_() + p.grad.zero_() + except: + pass + t0 = time.time() + torch.cuda.synchronize() + sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax, + (primpos, primrot, primscale), + template, warp if dowarp else None, + algo=algo, usebvh=usebvh, sortprims=sortprims, + maxhitboxes=maxhitboxes, synchitboxes=synchitboxes, + chlast=chlast, fadescale=fadescale, fadeexp=fadeexp, + accum=accum, termthresh=termthreshorig, + griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize) + t1 = time.time() + torch.cuda.synchronize() + sample1.backward(torch.ones_like(sample1), retain_graph=True) + torch.cuda.synchronize() + t2 = time.time() + tf += t1 - t0 + tb += t2 - t1 + + print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter)) + grads1 = [p.grad.detach().clone() for p in params] + + ############# compare results ############# + + print("-----------------------------------------------------------------") + print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda")) + ind = torch.argmax(torch.abs(sample0 - sample1)) + print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format( + "fwd", + torch.max(torch.abs(sample0 - sample1)).item(), + (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(), + torch.sqrt(torch.sum(sample0 * sample0)).item(), + torch.sqrt(torch.sum(sample1 * sample1)).item(), + ind.item(), + sample0.view(-1)[ind].item(), + sample1.view(-1)[ind].item())) + + for p, g0, g1 in zip(paramnames, grads0, grads1): + ind = torch.argmax(torch.abs(g0 - g1)) + print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format( + p, + torch.max(torch.abs(g0 - g1)).item(), + (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(), + torch.sqrt(torch.sum(g0 * g0)).item(), + torch.sqrt(torch.sum(g1 * g1)).item(), + ind.item(), + g0.view(-1)[ind].item(), + g1.view(-1)[ind].item())) + +if __name__ == "__main__": + gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True, + dowarp=False, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=0, griddim=3) + gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True, + dowarp=True, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=1, griddim=3) diff --git a/dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu b/dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0f25e64af3af2d2eabe40150f884e2e7eee9a972 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu @@ -0,0 +1,208 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
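+// Reader note (added comment, not part of the original source): this file
+// only sets up and dispatches the CUDA kernels defined in
+// mvpraymarch_subset_kernel.h. raymarch_forward_cuda and
+// raymarch_backward_cuda pack the primitive transforms (PrimTransfSRT), the
+// template/warp sampler (PrimSamplerTW) and the additive accumulator
+// (PrimAccumAdditive) into small Data structs, select a kernel instantiation
+// from the `algorithm` argument, and launch it on a grid of
+// ceil(W / blocksizex) x ceil(H / blocksizey) x N thread blocks.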
+ +#include +#include +#include +#include +#include +#include +#include + +#include "helper_math.h" + +#include "cudadispatch.h" + +#include "utils.h" + +#include "primtransf.h" +#include "primsampler.h" +#include "primaccum.h" + +#include "mvpraymarch_subset_kernel.h" + +typedef std::shared_ptr PrimTransfDataBase_ptr; +typedef std::shared_ptr PrimSamplerDataBase_ptr; +typedef std::shared_ptr PrimAccumDataBase_ptr; +typedef std::function mapfn_t; +typedef RaySubsetFixedBVH raysubset_t; + +void raymarch_forward_cuda( + int N, int H, int W, int K, + float * rayposim, + float * raydirim, + float stepsize, + float * tminmaxim, + + int * sortedobjid, + int * nodechildren, + float * nodeaabb, + float * primpos, + float * primrot, + float * primscale, + + int TD, int TH, int TW, + float * tplate, + int WD, int WH, int WW, + float * warp, + + float * rayrgbaim, + float * raysatim, + int * raytermim, + + int algorithm, + bool sortboxes, + int maxhitboxes, + bool synchitboxes, + bool chlast, + float fadescale, + float fadeexp, + int accum, + float termthresh, + int griddim, int blocksizex, int blocksizey, + cudaStream_t stream) { + dim3 blocksize(blocksizex, blocksizey); + dim3 gridsize; + gridsize = dim3( + (W + blocksize.x - 1) / blocksize.x, + (H + blocksize.y - 1) / blocksize.y, + N); + + std::shared_ptr primtransf_data; + primtransf_data = std::make_shared(PrimTransfSRT::Data{ + PrimTransfDataBase{}, + K, (float3*)primpos, nullptr, + K * 3, (float3*)primrot, nullptr, + K, (float3*)primscale, nullptr}); + std::shared_ptr primsampler_data; + if (algorithm == 1) { + primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ + PrimSamplerDataBase{}, + fadescale, fadeexp, + K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr, + K * WD * WH * WW * 3, WD, WH, WW, warp, nullptr}); + } else { + primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ + PrimSamplerDataBase{}, + fadescale, fadeexp, + K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr, + 0, 0, 0, 0, nullptr, nullptr}); + } + std::shared_ptr primaccum_data = std::make_shared(PrimAccumAdditive::Data{ + PrimAccumDataBase{}, + termthresh, H * W, W, 1, (float4*)rayrgbaim, nullptr, (float3*)raysatim}); + + std::map dispatcher = { + {0, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW, PrimAccumAdditive>)}, + {1, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW, PrimAccumAdditive>)}}; + + auto iter = dispatcher.find(algorithm); + if (iter != dispatcher.end()) { + (iter->second)( + gridsize, blocksize, stream, + N, H, W, K, + reinterpret_cast(rayposim), + reinterpret_cast(raydirim), + stepsize, + reinterpret_cast(tminmaxim), + reinterpret_cast(sortedobjid), + reinterpret_cast(nodechildren), + reinterpret_cast(nodeaabb), + primtransf_data, + primsampler_data, + primaccum_data); + } +} + +void raymarch_backward_cuda( + int N, int H, int W, int K, + float * rayposim, + float * raydirim, + float stepsize, + float * tminmaxim, + int * sortedobjid, + int * nodechildren, + float * nodeaabb, + + float * primpos, + float * grad_primpos, + float * primrot, + float * grad_primrot, + float * primscale, + float * grad_primscale, + + int TD, int TH, int TW, + float * tplate, + float * grad_tplate, + int WD, int WH, int WW, + float * warp, + float * grad_warp, + + float * rayrgbaim, + float * grad_rayrgba, + float * raysatim, + int * raytermim, + + int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes, + bool chlast, float fadescale, float fadeexp, 
int accum, float termthresh, + int griddim, int blocksizex, int blocksizey, + + cudaStream_t stream) { + dim3 blocksize(blocksizex, blocksizey); + dim3 gridsize; + gridsize = dim3( + (W + blocksize.x - 1) / blocksize.x, + (H + blocksize.y - 1) / blocksize.y, + N); + + std::shared_ptr primtransf_data; + primtransf_data = std::make_shared(PrimTransfSRT::Data{ + PrimTransfDataBase{}, + K, (float3*)primpos, (float3*)grad_primpos, + K * 3, (float3*)primrot, (float3*)grad_primrot, + K, (float3*)primscale, (float3*)grad_primscale}); + std::shared_ptr primsampler_data; + if (algorithm == 1) { + primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ + PrimSamplerDataBase{}, + fadescale, fadeexp, + K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate, + K * WD * WH * WW * 3, WD, WH, WW, warp, grad_warp}); + } else { + primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ + PrimSamplerDataBase{}, + fadescale, fadeexp, + K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate, + 0, 0, 0, 0, nullptr, nullptr}); + } + std::shared_ptr primaccum_data = std::make_shared(PrimAccumAdditive::Data{ + PrimAccumDataBase{}, + termthresh, H * W, W, 1, (float4*)rayrgbaim, (float4*)grad_rayrgba, (float3*)raysatim}); + + std::map dispatcher = { + {0, make_cudacall(raymarch_subset_backward_kernel, PrimAccumAdditive>)}, + {1, make_cudacall(raymarch_subset_backward_kernel, PrimAccumAdditive>)}}; + + auto iter = dispatcher.find(algorithm); + if (iter != dispatcher.end()) { + (iter->second)( + gridsize, blocksize, stream, + N, H, W, K, + reinterpret_cast(rayposim), + reinterpret_cast(raydirim), + stepsize, + reinterpret_cast(tminmaxim), + reinterpret_cast(sortedobjid), + reinterpret_cast(nodechildren), + reinterpret_cast(nodeaabb), + primtransf_data, + primsampler_data, + primaccum_data); + } +} diff --git a/dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h b/dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..4586af383e8e9c39d479d41e9565d446000c379f --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h @@ -0,0 +1,218 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +template< + int maxhitboxes, + int nwarps, + class RaySubsetT=RaySubsetFixedBVH, + class PrimTransfT=PrimTransfSRT, + class PrimSamplerT=PrimSamplerTW, + class PrimAccumT=PrimAccumAdditive> +__global__ void raymarch_subset_forward_kernel( + int N, int H, int W, int K, + float3 * rayposim, + float3 * raydirim, + float stepsize, + float2 * tminmaxim, + int * sortedobjid, + int2 * nodechildren, + float3 * nodeaabb, + typename PrimTransfT::Data primtransf_data, + typename PrimSamplerT::Data primsampler_data, + typename PrimAccumT::Data primaccum_data + ) { + int w = blockIdx.x * blockDim.x + threadIdx.x; + int h = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.z; + bool validthread = (w < W) && (h < H) && (n 0 ? 1 : maxhitboxes]; + __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1]; + int * hitboxes_ptr = nwarps > 0 ? 
hitboxes_sh + maxhitboxes * warpid : hitboxes; + int nhitboxes = 0; + + // find raytminmax + float2 rtminmax = make_float2(std::numeric_limits::infinity(), -std::numeric_limits::infinity()); + RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax, + sortedobjid, nodechildren, nodeaabb, + primtransf_data, hitboxes_ptr, nhitboxes); + rtminmax.x = max(rtminmax.x, tminmax.x); + rtminmax.y = min(rtminmax.y, tminmax.y); + __syncwarp(warpmask); + + float t = tminmax.x; + raypos = raypos + raydir * tminmax.x; + + int incs = floor((rtminmax.x - t) / stepsize); + t += incs * stepsize; + raypos += raydir * incs * stepsize; + + PrimAccumT pa; + + while (!__all_sync(warpmask, t > rtminmax.y + 1e-5f || pa.is_done())) { + for (int ks = 0; ks < nhitboxes; ++ks) { + int k = hitboxes_ptr[ks]; + + // compute primitive-relative coordinate + PrimTransfT pt; + float3 samplepos = pt.forward(primtransf_data, k, raypos); + + if (pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f) { + // sample + PrimSamplerT ps; + float4 sample = ps.forward(primsampler_data, k, samplepos); + + // accumulate + pa.forward_prim(primaccum_data, sample, stepsize); + } + } + + // update position + t += stepsize; + raypos += raydir * stepsize; + } + + pa.write(primaccum_data); +} + +template < + bool forwarddir, + int maxhitboxes, + int nwarps, + class RaySubsetT=RaySubsetFixedBVH, + class PrimTransfT=PrimTransfSRT, + class PrimSamplerT=PrimSamplerTW, + class PrimAccumT=PrimAccumAdditive> +__global__ void raymarch_subset_backward_kernel( + int N, int H, int W, int K, + float3 * rayposim, + float3 * raydirim, + float stepsize, + float2 * tminmaxim, + int * sortedobjid, + int2 * nodechildren, + float3 * nodeaabb, + typename PrimTransfT::Data primtransf_data, + typename PrimSamplerT::Data primsampler_data, + typename PrimAccumT::Data primaccum_data + ) { + int w = blockIdx.x * blockDim.x + threadIdx.x; + int h = blockIdx.y * blockDim.y + threadIdx.y; + int n = blockIdx.z; + bool validthread = (w < W) && (h < H) && (n 0 ? 1 : maxhitboxes]; + __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1]; + int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes; + int nhitboxes = 0; + + // find raytminmax + float2 rtminmax = make_float2(std::numeric_limits::infinity(), -std::numeric_limits::infinity()); + RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax, + sortedobjid, nodechildren, nodeaabb, + primtransf_data, hitboxes_ptr, nhitboxes); + rtminmax.x = max(rtminmax.x, tminmax.x); + rtminmax.y = min(rtminmax.y, tminmax.y); + __syncwarp(warpmask); + + // set up raymarching position + float t = tminmax.x; + raypos = raypos + raydir * tminmax.x; + + int incs = floor((rtminmax.x - t) / stepsize); + t += incs * stepsize; + raypos += raydir * incs * stepsize; + + if (!forwarddir) { + int nsteps = pa.get_nsteps(); + t += nsteps * stepsize; + raypos += raydir * nsteps * stepsize; + } + + while (__any_sync(warpmask, ( + (forwarddir && t < rtminmax.y + 1e-5f || + !forwarddir && t > rtminmax.x - 1e-5f) && + !pa.is_done()))) { + for (int ks = 0; ks < nhitboxes; ++ks) { + int k = hitboxes_ptr[forwarddir ? 
ks : nhitboxes - ks - 1]; + + PrimTransfT pt; + float3 samplepos = pt.forward(primtransf_data, k, raypos); + + bool evalprim = pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f; + + float3 dL_samplepos = make_float3(0.f); + if (evalprim) { + PrimSamplerT ps; + float4 sample = ps.forward(primsampler_data, k, samplepos); + + float4 dL_sample = pa.forwardbackward_prim(primaccum_data, sample, stepsize); + + dL_samplepos = ps.backward(primsampler_data, k, samplepos, sample, dL_sample, validthread); + } + + if (__any_sync(warpmask, evalprim)) { + pt.backward(primtransf_data, k, samplepos, dL_samplepos, validthread && evalprim); + } + } + + if (forwarddir) { + t += stepsize; + raypos += raydir * stepsize; + } else { + t -= stepsize; + raypos -= raydir * stepsize; + } + } +} + diff --git a/dva/mvp/extensions/mvpraymarch/primaccum.h b/dva/mvp/extensions/mvpraymarch/primaccum.h new file mode 100644 index 0000000000000000000000000000000000000000..200beb1f2e6cd64449351344a56481cf3395a1f3 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/primaccum.h @@ -0,0 +1,101 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef MVPRAYMARCHER_PRIMACCUM_H_ +#define MVPRAYMARCHER_PRIMACCUM_H_ + +struct PrimAccumDataBase { + typedef PrimAccumDataBase base; +}; + +struct PrimAccumAdditive { + struct Data : public PrimAccumDataBase { + float termthresh; + + int nstride, hstride, wstride; + float4 * rayrgbaim; + float4 * grad_rayrgbaim; + float3 * raysatim; + + __forceinline__ __device__ void n_stride(int n, int h, int w) { + rayrgbaim += n * nstride + h * hstride + w * wstride; + grad_rayrgbaim += n * nstride + h * hstride + w * wstride; + if (raysatim) { + raysatim += n * nstride + h * hstride + w * wstride; + } + } + }; + + float4 rayrgba; + float3 raysat; + bool sat; + float4 dL_rayrgba; + + __forceinline__ __device__ PrimAccumAdditive() : + rayrgba(make_float4(0.f)), + raysat(make_float3(-1.f)), + sat(false) { + } + + __forceinline__ __device__ bool is_done() const { + return sat; + } + + __forceinline__ __device__ int get_nsteps() const { + return 0; + } + + __forceinline__ __device__ void write(const Data & data) { + *data.rayrgbaim = rayrgba; + if (data.raysatim) { + *data.raysatim = raysat; + } + } + + __forceinline__ __device__ void read(const Data & data) { + dL_rayrgba = *data.grad_rayrgbaim; + raysat = *data.raysatim; + } + + __forceinline__ __device__ void forward_prim(const Data & data, float4 sample, float stepsize) { + // accumulate + float3 rgb = make_float3(sample); + float alpha = sample.w; + float newalpha = rayrgba.w + alpha * stepsize; + float contrib = fminf(newalpha, 1.f) - rayrgba.w; + + rayrgba += make_float4(rgb, 1.f) * contrib; + + if (newalpha >= 1.f) { + // save saturation point + if (!sat) { + raysat = rgb; + } + sat = true; + } + } + + __forceinline__ __device__ float4 forwardbackward_prim(const Data & data, float4 sample, float stepsize) { + float3 rgb = make_float3(sample); + float4 rgb1 = make_float4(rgb, 1.f); + sample.w *= stepsize; + + bool thissat = rayrgba.w + sample.w >= 1.f; + sat = sat || thissat; + + float weight = sat ? (1.f - rayrgba.w) : sample.w; + + float3 dL_rgb = weight * make_float3(dL_rayrgba); + float dL_alpha = sat ? 0.f : + stepsize * dot(rgb1 - (raysat.x > -1.f ? 
make_float4(raysat, 1.f) : make_float4(0.f)), dL_rayrgba); + + rayrgba += make_float4(rgb, 1.f) * weight; + + return make_float4(dL_rgb, dL_alpha); + } +}; + +#endif diff --git a/dva/mvp/extensions/mvpraymarch/primsampler.h b/dva/mvp/extensions/mvpraymarch/primsampler.h new file mode 100644 index 0000000000000000000000000000000000000000..94cf4ed7db78d1fbe99709e657dc6a13e868f804 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/primsampler.h @@ -0,0 +1,94 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef MVPRAYMARCHER_PRIMSAMPLER_H_ +#define MVPRAYMARCHER_PRIMSAMPLER_H_ + +struct PrimSamplerDataBase { + typedef PrimSamplerDataBase base; +}; + +template< + bool dowarp, + template class GridSamplerT=GridSamplerChlast> +struct PrimSamplerTW { + struct Data : public PrimSamplerDataBase { + float fadescale, fadeexp; + + int tplate_nstride; + int TD, TH, TW; + float * tplate; + float * grad_tplate; + + int warp_nstride; + int WD, WH, WW; + float * warp; + float * grad_warp; + + __forceinline__ __device__ void n_stride(int n) { + tplate += n * tplate_nstride; + grad_tplate += n * tplate_nstride; + warp += n * warp_nstride; + grad_warp += n * warp_nstride; + } + }; + + float fade; + float * tplate_ptr; + float * warp_ptr; + float3 yy1; + + __forceinline__ __device__ float4 forward( + const Data & data, + int k, + float3 y0) { + fade = __expf(-data.fadescale * ( + __powf(abs(y0.x), data.fadeexp) + + __powf(abs(y0.y), data.fadeexp) + + __powf(abs(y0.z), data.fadeexp))); + + if (dowarp) { + warp_ptr = data.warp + (k * 3 * data.WD * data.WH * data.WW); + yy1 = GridSamplerT::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false); + } else { + yy1 = y0; + } + + tplate_ptr = data.tplate + (k * 4 * data.TD * data.TH * data.TW); + float4 sample = GridSamplerT::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false); + + sample.w *= fade; + + return sample; + } + + __forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0, + float4 sample, float4 dL_sample, bool validthread) { + float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3( + __powf(abs(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f), + __powf(abs(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f), + __powf(abs(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f)); + float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w; + + dL_sample.w *= fade; + + float * grad_tplate_ptr = data.grad_tplate + (k * 4 * data.TD * data.TH * data.TW); + float3 dL_y1 = GridSamplerT::backward(4, data.TD, data.TH, data.TW, + tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false); + + if (dowarp) { + float * grad_warp_ptr = data.grad_warp + (k * 3 * data.WD * data.WH * data.WW); + dL_y0 += GridSamplerT::backward(3, data.WD, data.WH, data.WW, + warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false); + } else { + dL_y0 += dL_y1; + } + + return dL_y0; + } +}; + +#endif diff --git a/dva/mvp/extensions/mvpraymarch/primtransf.h b/dva/mvp/extensions/mvpraymarch/primtransf.h new file mode 100644 index 0000000000000000000000000000000000000000..7c24f3247c255b6da72d7ee6028ff2856af64ea5 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/primtransf.h @@ -0,0 +1,182 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef MVPRAYMARCHER_PRIMTRANSF_H_ +#define MVPRAYMARCHER_PRIMTRANSF_H_ + +#include "utils.h" + +__forceinline__ __device__ void compute_aabb_srt( + float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps, + float3 & pmin, float3 & pmax) { + float3 p; + p = make_float3(-1.f, -1.f, -1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = p; + pmax = p; + + p = make_float3(1.f, -1.f, -1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); + + p = make_float3(-1.f, 1.f, -1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); + + p = make_float3(1.f, 1.f, -1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); + + p = make_float3(-1.f, -1.f, 1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); + + p = make_float3(1.f, -1.f, 1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); + + p = make_float3(-1.f, 1.f, 1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); + + p = make_float3(1.f, 1.f, 1.f) / ps; + p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; + + pmin = fminf(pmin, p); + pmax = fmaxf(pmax, p); +} + +struct PrimTransfDataBase { + typedef PrimTransfDataBase base; +}; + +struct PrimTransfSRT { + struct Data : public PrimTransfDataBase { + int primpos_nstride; + float3 * primpos; + float3 * grad_primpos; + int primrot_nstride; + float3 * primrot; + float3 * grad_primrot; + int primscale_nstride; + float3 * primscale; + float3 * grad_primscale; + + __forceinline__ __device__ void n_stride(int n) { + primpos += n * primpos_nstride; + grad_primpos += n * primpos_nstride; + primrot += n * primrot_nstride; + grad_primrot += n * primrot_nstride; + primscale += n * primscale_nstride; + grad_primscale += n * primscale_nstride; + } + + __forceinline__ __device__ float3 get_center(int n, int k) { + return primpos[n * primpos_nstride + k]; + } + + __forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) { + float3 pt = primpos[n * primpos_nstride + k]; + float3 pr0 = primrot[n * primrot_nstride + k * 3 + 0]; + float3 pr1 = primrot[n * primrot_nstride + k * 3 + 1]; + float3 pr2 = primrot[n * primrot_nstride + k * 3 + 2]; + float3 ps = primscale[n * primscale_nstride + k]; + + compute_aabb_srt(pt, pr0, pr1, pr2, ps, pmin, pmax); + } + }; + + float3 xmt; + float3 pr0; + float3 pr1; + float3 pr2; + float3 rxmt; + float3 ps; + + static __forceinline__ __device__ bool valid(float3 pos) { + return ( + pos.x > -1.f && pos.x < 1.f && + pos.y > -1.f && pos.y < 1.f && + pos.z > -1.f && pos.z < 1.f); + } + + __forceinline__ __device__ float3 forward( + const Data & data, + int k, + float3 x) { + float3 pt = data.primpos[k]; + pr0 = data.primrot[(k) * 3 + 0]; + pr1 = data.primrot[(k) * 3 + 1]; + pr2 = data.primrot[(k) * 3 + 2]; + ps = data.primscale[k]; + xmt = x - pt; + rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z; + float3 y0 = rxmt * ps; + return y0; + } + + static __forceinline__ __device__ void forward2( + const Data & data, + int k, + float3 r, float3 d, float3 & rout, float3 & 
dout) { + float3 pt = data.primpos[k]; + float3 pr0 = data.primrot[k * 3 + 0]; + float3 pr1 = data.primrot[k * 3 + 1]; + float3 pr2 = data.primrot[k * 3 + 2]; + float3 ps = data.primscale[k]; + float3 xmt = r - pt; + float3 dmt = d; + float3 rxmt = pr0 * xmt.x; + float3 rdmt = pr0 * dmt.x; + rxmt += pr1 * xmt.y; + rdmt += pr1 * dmt.y; + rxmt += pr2 * xmt.z; + rdmt += pr2 * dmt.z; + rout = rxmt * ps; + dout = rdmt * ps; + } + + __forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) { + fastAtomicAdd((float*)data.grad_primscale + k * 3 + 0, validthread ? rxmt.x * dL_y0.x : 0.f); + fastAtomicAdd((float*)data.grad_primscale + k * 3 + 1, validthread ? rxmt.y * dL_y0.y : 0.f); + fastAtomicAdd((float*)data.grad_primscale + k * 3 + 2, validthread ? rxmt.z * dL_y0.z : 0.f); + + dL_y0 *= ps; + float3 gpr0 = xmt.x * dL_y0; + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 0, validthread ? gpr0.x : 0.f); + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 1, validthread ? gpr0.y : 0.f); + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 2, validthread ? gpr0.z : 0.f); + + float3 gpr1 = xmt.y * dL_y0; + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 0, validthread ? gpr1.x : 0.f); + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 1, validthread ? gpr1.y : 0.f); + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 2, validthread ? gpr1.z : 0.f); + + float3 gpr2 = xmt.z * dL_y0; + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 0, validthread ? gpr2.x : 0.f); + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 1, validthread ? gpr2.y : 0.f); + fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 2, validthread ? gpr2.z : 0.f); + + fastAtomicAdd((float*)data.grad_primpos + k * 3 + 0, validthread ? -dot(pr0, dL_y0) : 0.f); + fastAtomicAdd((float*)data.grad_primpos + k * 3 + 1, validthread ? -dot(pr1, dL_y0) : 0.f); + fastAtomicAdd((float*)data.grad_primpos + k * 3 + 2, validthread ? -dot(pr2, dL_y0) : 0.f); + } +}; + +#endif diff --git a/dva/mvp/extensions/mvpraymarch/setup.py b/dva/mvp/extensions/mvpraymarch/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b967dff8adb8aa63ea05290e0beda13225ae73d9 --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/setup.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup + +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + +if __name__ == "__main__": + import torch + setup( + name="mvpraymarch", + ext_modules=[ + CUDAExtension( + "mvpraymarchlib", + sources=["mvpraymarch.cpp", "mvpraymarch_kernel.cu", "bvh.cu"], + extra_compile_args={ + "nvcc": [ + "-use_fast_math", + "-arch=sm_70", + "-std=c++17", + "-lineinfo", + ] + } + ) + ], + cmdclass={"build_ext": BuildExtension} + ) diff --git a/dva/mvp/extensions/mvpraymarch/utils.h b/dva/mvp/extensions/mvpraymarch/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..21d2766d5b18982016b6a5d9cbbb5e351d1f222f --- /dev/null +++ b/dva/mvp/extensions/mvpraymarch/utils.h @@ -0,0 +1,847 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
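+// Reader note (added comment, not part of the original source): this header
+// collects the device-side helpers shared by the raymarching kernels:
+// fastAtomicAdd (a warp-level shuffle-down reduction followed by a single
+// atomicAdd from lane 0), safe_add_3d (bounds-checked atomic scatter into a
+// 3D volume), and trilinear grid samplers in channels-first
+// (grid_sample_forward/backward) and channels-last
+// (grid_sample_chlast_forward/backward) layouts.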
+ +#ifndef MVPRAYMARCHER_UTILS_H_ +#define MVPRAYMARCHER_UTILS_H_ + +#include +#include + +#include + +#include "helper_math.h" + +static __forceinline__ __device__ float clock_diff(long long int end, long long int start) { + long long int max_clock = std::numeric_limits::max(); + return (end= b.x && a.y >= b.y && a.z >= b.z; +} + +static __forceinline__ __device__ +bool alllt(float3 a, float3 b) { + return a.x <= b.x && a.y <= b.y && a.z <= b.z; +} + +static __forceinline__ __device__ +float4 softplus(float4 x) { + return make_float4( + x.x > 20.f ? x.x : logf(1.f + expf(x.x)), + x.y > 20.f ? x.y : logf(1.f + expf(x.y)), + x.z > 20.f ? x.z : logf(1.f + expf(x.z)), + x.w > 20.f ? x.w : logf(1.f + expf(x.w))); +} + +static __forceinline__ __device__ +float softplus(float x) { + // that's a neat trick + return __logf(1.f + __expf(-abs(x))) + max(x, 0.f); +} +static __forceinline__ __device__ +float softplus_grad(float x) { + // that's a neat trick + float expnabsx = __expf(-abs(x)); + return (0.5f - expnabsx / (1.f + expnabsx)) * copysign(1.f, x) + 0.5f; +} + + +static __forceinline__ __device__ +float4 sigmoid(float4 x) { + return make_float4( + 1.f / (1.f + expf(-x.x)), + 1.f / (1.f + expf(-x.y)), + 1.f / (1.f + expf(-x.z)), + 1.f / (1.f + expf(-x.w))); +} + +// perform reduction on warp, then call atomicAdd for only one lane +static __forceinline__ __device__ void fastAtomicAdd(float * ptr, float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + + const int laneid = (threadIdx.y * blockDim.x + threadIdx.x) % 32; + if (laneid == 0) { + atomicAdd(ptr, val); + } +} + + +static __forceinline__ __device__ +bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +static __forceinline__ __device__ +void safe_add_3d(float *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + float delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + atomicAdd(data + d * sD + h * sH + w * sW, delta); + } +} + +static __forceinline__ __device__ +void safe_add_3d(float3 *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + float3 delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 0, delta.x); + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 1, delta.y); + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 2, delta.z); + } +} + +static __forceinline__ __device__ +void safe_add_3d(float4 *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + float4 delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 0, delta.x); + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 1, delta.y); + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 2, delta.z); + atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 3, delta.w); + } +} + +static __forceinline__ __device__ +float clip_coordinates(float in, int clip_limit) { + return ::min(static_cast(clip_limit - 1), ::max(in, 0.f)); +} + +template +static __forceinline__ __device__ +float clip_coordinates_set_grad(float in, int clip_limit, scalar_t *grad_in) { + if (in < 0.f) { + *grad_in = static_cast(0); + return 0.f; + } else { + float max = static_cast(clip_limit - 1); + if (in > max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +template +static 
__device__ out_t grid_sample_forward(int C, int inp_D, int inp_H, + int inp_W, float* vals, float3 pos, bool border) { + int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D; + int out_sC = 1; + + // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1] + float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1); + float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1); + float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1); + + if (border) { + // clip coordinates to image borders + ix = clip_coordinates(ix, inp_W); + iy = clip_coordinates(iy, inp_H); + iz = clip_coordinates(iz, inp_D); + } + + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast(::floor(ix)); + int iy_tnw = static_cast(::floor(iy)); + int iz_tnw = static_cast(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + out_t result; + //auto inp_ptr_NC = input.data + n * inp_sN; + //auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + float * inp_ptr_NC = vals; + float * out_ptr_NCDHW = &result.x; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + *out_ptr_NCDHW = static_cast(0); + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, 
inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } + return result; +} + +template +static __device__ float3 grid_sample_backward(int C, int inp_D, int inp_H, + int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out, + bool border) { + int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D; + int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H, gInp_sC = inp_W * inp_H * inp_D; + int gOut_sC = 1; + + // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1] + float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1); + float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1); + float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1); + + float gix_mult = (inp_W - 1.f) / 2; + float giy_mult = (inp_H - 1.f) / 2; + float giz_mult = (inp_D - 1.f) / 2; + + if (border) { + // clip coordinates to image borders + ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult); + iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult); + iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult); + } + + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast(::floor(ix)); + int iy_tnw = static_cast(::floor(iy)); + int iz_tnw = static_cast(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + float gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); + //float *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + //float *gInp_ptr_NC = grad_input.data + n * gInp_sN; + //float *inp_ptr_NC = input.data + n * inp_sN; + float *gOut_ptr_NCDHW = &grad_out.x; + float *gInp_ptr_NC = grad_vals; + float *inp_ptr_NC = vals; + // calculate bilinear weighted pixel value and set output pixel + for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) { + float gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input + safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut); + 
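+      // Added comment (not in the original): this call and the seven that
+      // follow scatter the upstream gradient gOut into grad_input, weighted
+      // by the same trilinear corner weights as the forward pass; the
+      // within_bounds_3d blocks further below then accumulate the gradient
+      // with respect to the sampling position (gix, giy, giz) by
+      // differentiating those weights.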
safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut); + safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut); + safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut); + safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut); + + // calculate grad_grid + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + float tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + float tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + float tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + float tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + float bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + float bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + float bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + float bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz); +} + +// this dummy struct necessary 
because c++ is dumb +template +struct GridSampler { + static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W, + float* vals, float3 pos, bool border) { + return grid_sample_forward(C, inp_D, inp_H, inp_W, vals, pos, border); + } + + static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W, + float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) { + return grid_sample_backward(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border); + } +}; + +//template +//__device__ void cswap ( T& a, T& b ) { +// T c(a); a=b; b=c; +//} + +static __forceinline__ __device__ +int within_bounds_3d_ind(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W ? ((d * H) + h) * W + w : -1; +} + +template +static __device__ out_t grid_sample_chlast_forward(int, int inp_D, int inp_H, + int inp_W, float * vals, float3 pos, bool border) { + int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H; + + // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1] + float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1); + float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1); + float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1); + + if (border) { + // clip coordinates to image borders + ix = clip_coordinates(ix, inp_W); + iy = clip_coordinates(iy, inp_H); + iz = clip_coordinates(iz, inp_D); + } + + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast(::floor(ix)); + int iy_tnw = static_cast(::floor(iy)); + int iz_tnw = static_cast(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + out_t result; + memset(&result, 0, sizeof(out_t)); + out_t * inp_ptr_NC = (out_t*)vals; + out_t * out_ptr_NCDHW = &result; + { + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * 
inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } + + return result; +} + +template +static __device__ float3 grid_sample_chlast_backward(int, int inp_D, int inp_H, + int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out, + bool border) { + int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H; + int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H; + + // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1] + float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1); + float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1); + float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1); + + float gix_mult = (inp_W - 1.f) / 2; + float giy_mult = (inp_H - 1.f) / 2; + float giz_mult = (inp_D - 1.f) / 2; + + if (border) { + // clip coordinates to image borders + ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult); + iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult); + iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult); + } + + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast(::floor(ix)); + int iy_tnw = static_cast(::floor(iy)); + int iz_tnw = static_cast(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + float gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); + out_t *gOut_ptr_NCDHW = &grad_out; + out_t *gInp_ptr_NC = (out_t*)grad_vals; + out_t *inp_ptr_NC = (out_t*)vals; + + // calculate bilinear weighted pixel value and set output pixel + { + out_t gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input + safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, 
tne * gOut); + safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut); + safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut); + safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut); + safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut); + + // calculate grad_grid + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + out_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= (iy_bse - iy) * (iz_bse - iz) * dot(tnw_val, gOut); + giy -= (ix_bse - ix) * (iz_bse - iz) * dot(tnw_val, gOut); + giz -= (ix_bse - ix) * (iy_bse - iy) * dot(tnw_val, gOut); + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + out_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += (iy_bsw - iy) * (iz_bsw - iz) * dot(tne_val, gOut); + giy -= (ix - ix_bsw) * (iz_bsw - iz) * dot(tne_val, gOut); + giz -= (ix - ix_bsw) * (iy_bsw - iy) * dot(tne_val, gOut); + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + out_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= (iy - iy_bne) * (iz_bne - iz) * dot(tsw_val, gOut); + giy += (ix_bne - ix) * (iz_bne - iz) * dot(tsw_val, gOut); + giz -= (ix_bne - ix) * (iy - iy_bne) * dot(tsw_val, gOut); + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + out_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += (iy - iy_bnw) * (iz_bnw - iz) * dot(tse_val, gOut); + giy += (ix - ix_bnw) * (iz_bnw - iz) * dot(tse_val, gOut); + giz -= (ix - ix_bnw) * (iy - iy_bnw) * dot(tse_val, gOut); + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + out_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= (iy_tse - iy) * (iz - iz_tse) * dot(bnw_val, gOut); + giy -= (ix_tse - ix) * (iz - iz_tse) * dot(bnw_val, gOut); + giz += (ix_tse - ix) * (iy_tse - iy) * dot(bnw_val, gOut); + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + out_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += (iy_tsw - iy) * (iz - iz_tsw) * dot(bne_val, gOut); + giy -= (ix - ix_tsw) * (iz - iz_tsw) * dot(bne_val, gOut); + giz += (ix - ix_tsw) * (iy_tsw - iy) * dot(bne_val, gOut); + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + out_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= (iy - iy_tne) * (iz - iz_tne) * dot(bsw_val, gOut); + giy += (ix_tne - ix) * (iz - iz_tne) * dot(bsw_val, gOut); + giz += (ix_tne - ix) * (iy - iy_tne) * dot(bsw_val, gOut); + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + out_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += (iy - iy_tnw) * (iz - iz_tnw) * dot(bse_val, gOut); + giy += (ix - ix_tnw) * (iz - iz_tnw) * dot(bse_val, gOut); + giz += (ix - ix_tnw) * (iy - iy_tnw) * dot(bse_val, gOut); + } + } + + return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz); +} + +template +struct 
GridSamplerChlast { + static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W, + float* vals, float3 pos, bool border) { + return grid_sample_chlast_forward(C, inp_D, inp_H, inp_W, vals, pos, border); + } + + static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W, + float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) { + return grid_sample_chlast_backward(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border); + } +}; + + +inline __host__ __device__ float min_component(float3 a) { + return fminf(fminf(a.x,a.y),a.z); +} + +inline __host__ __device__ float max_component(float3 a) { + return fmaxf(fmaxf(a.x,a.y),a.z); +} + + inline __host__ __device__ float3 abs(float3 a) { + return make_float3(abs(a.x), abs(a.y), abs(a.z)); +} + +__forceinline__ __device__ bool ray_aabb_hit(float3 p0, float3 p1, float3 raypos, float3 raydir) { + float3 t0 = (p0 - raypos) / raydir; + float3 t1 = (p1 - raypos) / raydir; + float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1); + + return max_component(tmin) <= min_component(tmax); +} + +__forceinline__ __device__ bool ray_aabb_hit_ird(float3 p0, float3 p1, float3 raypos, float3 ird) { + float3 t0 = (p0 - raypos) * ird; + float3 t1 = (p1 - raypos) * ird; + float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1); + + return max_component(tmin) <= min_component(tmax); + +} +__forceinline__ __device__ void ray_aabb_hit_ird_tminmax(float3 p0, float3 p1, + float3 raypos, float3 ird, float &otmin, float &otmax) { + float3 t0 = (p0 - raypos) * ird; + float3 t1 = (p1 - raypos) * ird; + float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1); + tmin = fminf(t0,t1); + tmax = fmaxf(t0,t1); + otmin = max_component(tmin); + otmax = min_component(tmax); +} + +inline __device__ bool aabb_intersect(float3 p0, float3 p1, float3 r0, float3 rd, float &tmin, float &tmax) { + float tymin, tymax, tzmin, tzmax; + const float3 bounds[2] = {p0, p1}; + float3 ird = 1.0f/rd; + int sx = (ird.x<0) ? 1 : 0; + int sy = (ird.y<0) ? 1 : 0; + int sz = (ird.z<0) ? 
1 : 0; + tmin = (bounds[sx].x - r0.x) * ird.x; + tmax = (bounds[1-sx].x - r0.x) * ird.x; + tymin = (bounds[sy].y - r0.y) * ird.y; + tymax = (bounds[1-sy].y - r0.y) * ird.y; + + if ((tmin > tymax) || (tymin > tmax)) + return false; + if (tymin > tmin) + tmin = tymin; + if (tymax < tmax) + tmax = tymax; + + tzmin = (bounds[sz].z - r0.z) * ird.z; + tzmax = (bounds[1-sz].z - r0.z) * ird.z; + + if ((tmin > tzmax) || (tzmin > tmax)) + return false; + if (tzmin > tmin) + tmin = tzmin; + if (tzmax < tmax) + tmax = tzmax; + + return true; +} + +template +static __forceinline__ __device__ void ray_subset_fixedbvh( + unsigned warpmask, + int K, + float3 raypos, + float3 raydir, + float2 tminmax, + float2 &rtminmax, + int * sortedobjid, + int2 * nodechildren, + float3 * nodeaabb, + const typename PrimTransfT::Data & primtransf_data, + int *hitboxes, + int & num) { + float3 iraydir = 1.0f/raydir; + int stack[64]; + int* stack_ptr = stack; + *stack_ptr++ = -1; + int node = 0; + do { + // check if we're in a leaf + if (node >= (K - 1)) { + { + int k = node - (K - 1); + + float3 r0, rd; + PrimTransfT::forward2(primtransf_data, k, raypos, raydir, r0, rd); + + float3 ird = 1.0f/rd; + float3 t0 = (-1.f - r0) * ird; + float3 t1 = (1.f - r0) * ird; + float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1); + + float trmin = max_component(tmin); + float trmax = min_component(tmax); + + bool intersection = trmin <= trmax; + + if (intersection) { + // hit + rtminmax.x = fminf(rtminmax.x, trmin); + rtminmax.y = fmaxf(rtminmax.y, trmax); + } + + if (sync) { + intersection = __any_sync(warpmask, intersection); + } + + if (intersection) { + if (sortboxes) { + if (num < maxhitboxes) { + int j = num - 1; + while (j >= 0 && hitboxes[j] > k) { + hitboxes[j + 1] = hitboxes[j]; + j = j - 1; + } + hitboxes[j + 1] = k; + num++; + } + } else { + if (num < maxhitboxes) { + hitboxes[num++] = k; + } + } + } + } + + node = *--stack_ptr; + } else { + int2 children = make_int2(node * 2 + 1, node * 2 + 2); + + // check if we're in each child's bbox + float3 * nodeaabb_ptr = nodeaabb + children.x * 2; + bool traverse_l = ray_aabb_hit_ird(nodeaabb_ptr[0], nodeaabb_ptr[1], raypos, iraydir); + bool traverse_r = ray_aabb_hit_ird(nodeaabb_ptr[2], nodeaabb_ptr[3], raypos, iraydir); + + if (sync) { + traverse_l = __any_sync(warpmask, traverse_l); + traverse_r = __any_sync(warpmask, traverse_r); + } + + // update stack + if (!traverse_l && !traverse_r) { + node = *--stack_ptr; + } else { + node = traverse_l ? children.x : children.y; + if (traverse_l && traverse_r) { + *stack_ptr++ = children.y; + } + } + + if (sync) { + __syncwarp(warpmask); + } + } + } while (node != -1); +} + +template +struct RaySubsetFixedBVH { + static __forceinline__ __device__ void forward( + unsigned warpmask, + int K, + float3 raypos, + float3 raydir, + float2 tminmax, + float2 &rtminmax, + int * sortedobjid, + int2 * nodechildren, + float3 * nodeaabb, + const typename PrimTransfT::Data & primtransf_data, + int *hitboxes, + int & num) { + ray_subset_fixedbvh( + warpmask, K, raypos, raydir, tminmax, rtminmax, + sortedobjid, nodechildren, nodeaabb, primtransf_data, hitboxes, num); + } +}; + +#endif diff --git a/dva/mvp/extensions/utils/helper_math.h b/dva/mvp/extensions/utils/helper_math.h new file mode 100644 index 0000000000000000000000000000000000000000..c9c07c3e74bbd1f469740f95c45a2eae49322e99 --- /dev/null +++ b/dva/mvp/extensions/utils/helper_math.h @@ -0,0 +1,1453 @@ +/** + * Copyright 1993-2013 NVIDIA Corporation. All rights reserved. 
+ * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +/* + * This file implements common mathematical operations on vector types + * (float3, float4 etc.) since these are not provided as standard by CUDA. + * + * The syntax is modeled on the Cg standard library. + * + * This is part of the Helper library includes + * + * Thanks to Linh Hah for additions and fixes. + */ + +#ifndef HELPER_MATH_H +#define HELPER_MATH_H + +#include "cuda_runtime.h" + +typedef unsigned int uint; +typedef unsigned short ushort; + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +#ifndef __CUDACC__ +#include + +//////////////////////////////////////////////////////////////////////////////// +// host implementations of CUDA functions +//////////////////////////////////////////////////////////////////////////////// + +inline float fminf(float a, float b) +{ + return a < b ? a : b; +} + +inline float fmaxf(float a, float b) +{ + return a > b ? a : b; +} + +inline int max(int a, int b) +{ + return a > b ? a : b; +} + +inline int min(int a, int b) +{ + return a < b ? a : b; +} + +inline float rsqrtf(float x) +{ + return 1.0f / sqrtf(x); +} +#endif + +//////////////////////////////////////////////////////////////////////////////// +// constructors +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 make_float2(float s) +{ + return make_float2(s, s); +} +inline __host__ __device__ float2 make_float2(float3 a) +{ + return make_float2(a.x, a.y); +} +inline __host__ __device__ float2 make_float2(int2 a) +{ + return make_float2(float(a.x), float(a.y)); +} +inline __host__ __device__ float2 make_float2(uint2 a) +{ + return make_float2(float(a.x), float(a.y)); +} + +inline __host__ __device__ int2 make_int2(int s) +{ + return make_int2(s, s); +} +inline __host__ __device__ int2 make_int2(int3 a) +{ + return make_int2(a.x, a.y); +} +inline __host__ __device__ int2 make_int2(uint2 a) +{ + return make_int2(int(a.x), int(a.y)); +} +inline __host__ __device__ int2 make_int2(float2 a) +{ + return make_int2(int(a.x), int(a.y)); +} + +inline __host__ __device__ uint2 make_uint2(uint s) +{ + return make_uint2(s, s); +} +inline __host__ __device__ uint2 make_uint2(uint3 a) +{ + return make_uint2(a.x, a.y); +} +inline __host__ __device__ uint2 make_uint2(int2 a) +{ + return make_uint2(uint(a.x), uint(a.y)); +} + +inline __host__ __device__ float3 make_float3(float s) +{ + return make_float3(s, s, s); +} +inline __host__ __device__ float3 make_float3(float2 a) +{ + return make_float3(a.x, a.y, 0.0f); +} +inline __host__ __device__ float3 make_float3(float2 a, float s) +{ + return make_float3(a.x, a.y, s); +} +inline __host__ __device__ float3 make_float3(float4 a) +{ + return make_float3(a.x, a.y, a.z); +} +inline __host__ __device__ float3 make_float3(int3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} +inline __host__ __device__ float3 make_float3(uint3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} + +inline __host__ __device__ int3 make_int3(int s) +{ + return make_int3(s, s, s); +} +inline __host__ __device__ int3 make_int3(int2 a) +{ + return make_int3(a.x, a.y, 0); +} +inline __host__ __device__ int3 make_int3(int2 a, int s) +{ + return 
make_int3(a.x, a.y, s); +} +inline __host__ __device__ int3 make_int3(uint3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} +inline __host__ __device__ int3 make_int3(float3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} + +inline __host__ __device__ uint3 make_uint3(uint s) +{ + return make_uint3(s, s, s); +} +inline __host__ __device__ uint3 make_uint3(uint2 a) +{ + return make_uint3(a.x, a.y, 0); +} +inline __host__ __device__ uint3 make_uint3(uint2 a, uint s) +{ + return make_uint3(a.x, a.y, s); +} +inline __host__ __device__ uint3 make_uint3(uint4 a) +{ + return make_uint3(a.x, a.y, a.z); +} +inline __host__ __device__ uint3 make_uint3(int3 a) +{ + return make_uint3(uint(a.x), uint(a.y), uint(a.z)); +} + +inline __host__ __device__ float4 make_float4(float s) +{ + return make_float4(s, s, s, s); +} +inline __host__ __device__ float4 make_float4(float3 a) +{ + return make_float4(a.x, a.y, a.z, 0.0f); +} +inline __host__ __device__ float4 make_float4(float3 a, float w) +{ + return make_float4(a.x, a.y, a.z, w); +} +inline __host__ __device__ float4 make_float4(int4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} +inline __host__ __device__ float4 make_float4(uint4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} + +inline __host__ __device__ int4 make_int4(int s) +{ + return make_int4(s, s, s, s); +} +inline __host__ __device__ int4 make_int4(int3 a) +{ + return make_int4(a.x, a.y, a.z, 0); +} +inline __host__ __device__ int4 make_int4(int3 a, int w) +{ + return make_int4(a.x, a.y, a.z, w); +} +inline __host__ __device__ int4 make_int4(uint4 a) +{ + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); +} +inline __host__ __device__ int4 make_int4(float4 a) +{ + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); +} + + +inline __host__ __device__ uint4 make_uint4(uint s) +{ + return make_uint4(s, s, s, s); +} +inline __host__ __device__ uint4 make_uint4(uint3 a) +{ + return make_uint4(a.x, a.y, a.z, 0); +} +inline __host__ __device__ uint4 make_uint4(uint3 a, uint w) +{ + return make_uint4(a.x, a.y, a.z, w); +} +inline __host__ __device__ uint4 make_uint4(int4 a) +{ + return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// negate +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator-(float2 &a) +{ + return make_float2(-a.x, -a.y); +} +inline __host__ __device__ int2 operator-(int2 &a) +{ + return make_int2(-a.x, -a.y); +} +inline __host__ __device__ float3 operator-(float3 &a) +{ + return make_float3(-a.x, -a.y, -a.z); +} +inline __host__ __device__ int3 operator-(int3 &a) +{ + return make_int3(-a.x, -a.y, -a.z); +} +inline __host__ __device__ float4 operator-(float4 &a) +{ + return make_float4(-a.x, -a.y, -a.z, -a.w); +} +inline __host__ __device__ int4 operator-(int4 &a) +{ + return make_int4(-a.x, -a.y, -a.z, -a.w); +} + +//////////////////////////////////////////////////////////////////////////////// +// addition +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator+(float2 a, float2 b) +{ + return make_float2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(float2 &a, float2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ float2 operator+(float2 a, float b) +{ + return make_float2(a.x + b, a.y + b); +} +inline 
__host__ __device__ float2 operator+(float b, float2 a) +{ + return make_float2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(float2 &a, float b) +{ + a.x += b; + a.y += b; +} + +inline __host__ __device__ int2 operator+(int2 a, int2 b) +{ + return make_int2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(int2 &a, int2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ int2 operator+(int2 a, int b) +{ + return make_int2(a.x + b, a.y + b); +} +inline __host__ __device__ int2 operator+(int b, int2 a) +{ + return make_int2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(int2 &a, int b) +{ + a.x += b; + a.y += b; +} + +inline __host__ __device__ uint2 operator+(uint2 a, uint2 b) +{ + return make_uint2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(uint2 &a, uint2 b) +{ + a.x += b.x; + a.y += b.y; +} +inline __host__ __device__ uint2 operator+(uint2 a, uint b) +{ + return make_uint2(a.x + b, a.y + b); +} +inline __host__ __device__ uint2 operator+(uint b, uint2 a) +{ + return make_uint2(a.x + b, a.y + b); +} +inline __host__ __device__ void operator+=(uint2 &a, uint b) +{ + a.x += b; + a.y += b; +} + + +inline __host__ __device__ float3 operator+(float3 a, float3 b) +{ + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(float3 &a, float3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ float3 operator+(float3 a, float b) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(float3 &a, float b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ int3 operator+(int3 a, int3 b) +{ + return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(int3 &a, int3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ int3 operator+(int3 a, int b) +{ + return make_int3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(int3 &a, int b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ uint3 operator+(uint3 a, uint3 b) +{ + return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(uint3 &a, uint3 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; +} +inline __host__ __device__ uint3 operator+(uint3 a, uint b) +{ + return make_uint3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(uint3 &a, uint b) +{ + a.x += b; + a.y += b; + a.z += b; +} + +inline __host__ __device__ int3 operator+(int b, int3 a) +{ + return make_int3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ uint3 operator+(uint b, uint3 a) +{ + return make_uint3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ float3 operator+(float b, float3 a) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} + +inline __host__ __device__ float4 operator+(float4 a, float4 b) +{ + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(float4 &a, float4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ float4 operator+(float4 a, float b) +{ + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ float4 operator+(float b, float4 a) +{ + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(float4 &a, float b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += 
b; +} + +inline __host__ __device__ int4 operator+(int4 a, int4 b) +{ + return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(int4 &a, int4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ int4 operator+(int4 a, int b) +{ + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ int4 operator+(int b, int4 a) +{ + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(int4 &a, int b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +inline __host__ __device__ uint4 operator+(uint4 a, uint4 b) +{ + return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(uint4 &a, uint4 b) +{ + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; +} +inline __host__ __device__ uint4 operator+(uint4 a, uint b) +{ + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ uint4 operator+(uint b, uint4 a) +{ + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); +} +inline __host__ __device__ void operator+=(uint4 &a, uint b) +{ + a.x += b; + a.y += b; + a.z += b; + a.w += b; +} + +//////////////////////////////////////////////////////////////////////////////// +// subtract +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator-(float2 a, float2 b) +{ + return make_float2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(float2 &a, float2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ float2 operator-(float2 a, float b) +{ + return make_float2(a.x - b, a.y - b); +} +inline __host__ __device__ float2 operator-(float b, float2 a) +{ + return make_float2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(float2 &a, float b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ int2 operator-(int2 a, int2 b) +{ + return make_int2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(int2 &a, int2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ int2 operator-(int2 a, int b) +{ + return make_int2(a.x - b, a.y - b); +} +inline __host__ __device__ int2 operator-(int b, int2 a) +{ + return make_int2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(int2 &a, int b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ uint2 operator-(uint2 a, uint2 b) +{ + return make_uint2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(uint2 &a, uint2 b) +{ + a.x -= b.x; + a.y -= b.y; +} +inline __host__ __device__ uint2 operator-(uint2 a, uint b) +{ + return make_uint2(a.x - b, a.y - b); +} +inline __host__ __device__ uint2 operator-(uint b, uint2 a) +{ + return make_uint2(b - a.x, b - a.y); +} +inline __host__ __device__ void operator-=(uint2 &a, uint b) +{ + a.x -= b; + a.y -= b; +} + +inline __host__ __device__ float3 operator-(float3 a, float3 b) +{ + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(float3 &a, float3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ float3 operator-(float3 a, float b) +{ + return make_float3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ float3 operator-(float b, float3 a) +{ + return make_float3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(float3 &a, float b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline 
__host__ __device__ int3 operator-(int3 a, int3 b) +{ + return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(int3 &a, int3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ int3 operator-(int3 a, int b) +{ + return make_int3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ int3 operator-(int b, int3 a) +{ + return make_int3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(int3 &a, int b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ uint3 operator-(uint3 a, uint3 b) +{ + return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ void operator-=(uint3 &a, uint3 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} +inline __host__ __device__ uint3 operator-(uint3 a, uint b) +{ + return make_uint3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ uint3 operator-(uint b, uint3 a) +{ + return make_uint3(b - a.x, b - a.y, b - a.z); +} +inline __host__ __device__ void operator-=(uint3 &a, uint b) +{ + a.x -= b; + a.y -= b; + a.z -= b; +} + +inline __host__ __device__ float4 operator-(float4 a, float4 b) +{ + return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(float4 &a, float4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ float4 operator-(float4 a, float b) +{ + return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ void operator-=(float4 &a, float b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +inline __host__ __device__ int4 operator-(int4 a, int4 b) +{ + return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(int4 &a, int4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ int4 operator-(int4 a, int b) +{ + return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ int4 operator-(int b, int4 a) +{ + return make_int4(b - a.x, b - a.y, b - a.z, b - a.w); +} +inline __host__ __device__ void operator-=(int4 &a, int b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +inline __host__ __device__ uint4 operator-(uint4 a, uint4 b) +{ + return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(uint4 &a, uint4 b) +{ + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; +} +inline __host__ __device__ uint4 operator-(uint4 a, uint b) +{ + return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); +} +inline __host__ __device__ uint4 operator-(uint b, uint4 a) +{ + return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w); +} +inline __host__ __device__ void operator-=(uint4 &a, uint b) +{ + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; +} + +//////////////////////////////////////////////////////////////////////////////// +// multiply +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator*(float2 a, float2 b) +{ + return make_float2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(float2 &a, float2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ float2 operator*(float2 a, float b) +{ + return make_float2(a.x * b, a.y * b); +} +inline __host__ __device__ float2 operator*(float b, float2 a) +{ + return make_float2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(float2 &a, float b) +{ + a.x *= b; + 
a.y *= b; +} + +inline __host__ __device__ int2 operator*(int2 a, int2 b) +{ + return make_int2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(int2 &a, int2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ int2 operator*(int2 a, int b) +{ + return make_int2(a.x * b, a.y * b); +} +inline __host__ __device__ int2 operator*(int b, int2 a) +{ + return make_int2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(int2 &a, int b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ uint2 operator*(uint2 a, uint2 b) +{ + return make_uint2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ void operator*=(uint2 &a, uint2 b) +{ + a.x *= b.x; + a.y *= b.y; +} +inline __host__ __device__ uint2 operator*(uint2 a, uint b) +{ + return make_uint2(a.x * b, a.y * b); +} +inline __host__ __device__ uint2 operator*(uint b, uint2 a) +{ + return make_uint2(b * a.x, b * a.y); +} +inline __host__ __device__ void operator*=(uint2 &a, uint b) +{ + a.x *= b; + a.y *= b; +} + +inline __host__ __device__ float3 operator*(float3 a, float3 b) +{ + return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(float3 &a, float3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ float3 operator*(float3 a, float b) +{ + return make_float3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ float3 operator*(float b, float3 a) +{ + return make_float3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(float3 &a, float b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ int3 operator*(int3 a, int3 b) +{ + return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(int3 &a, int3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ int3 operator*(int3 a, int b) +{ + return make_int3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ int3 operator*(int b, int3 a) +{ + return make_int3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(int3 &a, int b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ uint3 operator*(uint3 a, uint3 b) +{ + return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ void operator*=(uint3 &a, uint3 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; +} +inline __host__ __device__ uint3 operator*(uint3 a, uint b) +{ + return make_uint3(a.x * b, a.y * b, a.z * b); +} +inline __host__ __device__ uint3 operator*(uint b, uint3 a) +{ + return make_uint3(b * a.x, b * a.y, b * a.z); +} +inline __host__ __device__ void operator*=(uint3 &a, uint b) +{ + a.x *= b; + a.y *= b; + a.z *= b; +} + +inline __host__ __device__ float4 operator*(float4 a, float4 b) +{ + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(float4 &a, float4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ float4 operator*(float4 a, float b) +{ + return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ float4 operator*(float b, float4 a) +{ + return make_float4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(float4 &a, float b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +inline __host__ __device__ int4 operator*(int4 a, int4 b) +{ + return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(int4 
&a, int4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ int4 operator*(int4 a, int b) +{ + return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ int4 operator*(int b, int4 a) +{ + return make_int4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(int4 &a, int b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +inline __host__ __device__ uint4 operator*(uint4 a, uint4 b) +{ + return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} +inline __host__ __device__ void operator*=(uint4 &a, uint4 b) +{ + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; +} +inline __host__ __device__ uint4 operator*(uint4 a, uint b) +{ + return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); +} +inline __host__ __device__ uint4 operator*(uint b, uint4 a) +{ + return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w); +} +inline __host__ __device__ void operator*=(uint4 &a, uint b) +{ + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; +} + +//////////////////////////////////////////////////////////////////////////////// +// divide +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 operator/(float2 a, float2 b) +{ + return make_float2(a.x / b.x, a.y / b.y); +} +inline __host__ __device__ void operator/=(float2 &a, float2 b) +{ + a.x /= b.x; + a.y /= b.y; +} +inline __host__ __device__ float2 operator/(float2 a, float b) +{ + return make_float2(a.x / b, a.y / b); +} +inline __host__ __device__ void operator/=(float2 &a, float b) +{ + a.x /= b; + a.y /= b; +} +inline __host__ __device__ float2 operator/(float b, float2 a) +{ + return make_float2(b / a.x, b / a.y); +} + +inline __host__ __device__ float3 operator/(float3 a, float3 b) +{ + return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ void operator/=(float3 &a, float3 b) +{ + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; +} +inline __host__ __device__ float3 operator/(float3 a, float b) +{ + return make_float3(a.x / b, a.y / b, a.z / b); +} +inline __host__ __device__ void operator/=(float3 &a, float b) +{ + a.x /= b; + a.y /= b; + a.z /= b; +} +inline __host__ __device__ float3 operator/(float b, float3 a) +{ + return make_float3(b / a.x, b / a.y, b / a.z); +} + +inline __host__ __device__ float4 operator/(float4 a, float4 b) +{ + return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); +} +inline __host__ __device__ void operator/=(float4 &a, float4 b) +{ + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; + a.w /= b.w; +} +inline __host__ __device__ float4 operator/(float4 a, float b) +{ + return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); +} +inline __host__ __device__ void operator/=(float4 &a, float b) +{ + a.x /= b; + a.y /= b; + a.z /= b; + a.w /= b; +} +inline __host__ __device__ float4 operator/(float b, float4 a) +{ + return make_float4(b / a.x, b / a.y, b / a.z, b / a.w); +} + +//////////////////////////////////////////////////////////////////////////////// +// min +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fminf(float2 a, float2 b) +{ + return make_float2(fminf(a.x,b.x), fminf(a.y,b.y)); +} +inline __host__ __device__ float3 fminf(float3 a, float3 b) +{ + return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); +} +inline __host__ __device__ float4 fminf(float4 a, float4 b) +{ + return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), 
fminf(a.w,b.w)); +} + +inline __host__ __device__ int2 min(int2 a, int2 b) +{ + return make_int2(min(a.x,b.x), min(a.y,b.y)); +} +inline __host__ __device__ int3 min(int3 a, int3 b) +{ + return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} +inline __host__ __device__ int4 min(int4 a, int4 b) +{ + return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w)); +} + +inline __host__ __device__ uint2 min(uint2 a, uint2 b) +{ + return make_uint2(min(a.x,b.x), min(a.y,b.y)); +} +inline __host__ __device__ uint3 min(uint3 a, uint3 b) +{ + return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} +inline __host__ __device__ uint4 min(uint4 a, uint4 b) +{ + return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// max +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fmaxf(float2 a, float2 b) +{ + return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y)); +} +inline __host__ __device__ float3 fmaxf(float3 a, float3 b) +{ + return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); +} +inline __host__ __device__ float4 fmaxf(float4 a, float4 b) +{ + return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w)); +} + +inline __host__ __device__ int2 max(int2 a, int2 b) +{ + return make_int2(max(a.x,b.x), max(a.y,b.y)); +} +inline __host__ __device__ int3 max(int3 a, int3 b) +{ + return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} +inline __host__ __device__ int4 max(int4 a, int4 b) +{ + return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w)); +} + +inline __host__ __device__ uint2 max(uint2 a, uint2 b) +{ + return make_uint2(max(a.x,b.x), max(a.y,b.y)); +} +inline __host__ __device__ uint3 max(uint3 a, uint3 b) +{ + return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} +inline __host__ __device__ uint4 max(uint4 a, uint4 b) +{ + return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// lerp +// - linear interpolation between a and b, based on value t in [0, 1] range +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float lerp(float a, float b, float t) +{ + return a + t*(b-a); +} +inline __device__ __host__ float2 lerp(float2 a, float2 b, float t) +{ + return a + t*(b-a); +} +inline __device__ __host__ float3 lerp(float3 a, float3 b, float t) +{ + return a + t*(b-a); +} +inline __device__ __host__ float4 lerp(float4 a, float4 b, float t) +{ + return a + t*(b-a); +} + +//////////////////////////////////////////////////////////////////////////////// +// clamp +// - clamp the value v to be in the range [a, b] +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float clamp(float f, float a, float b) +{ + return fmaxf(a, fminf(f, b)); +} +inline __device__ __host__ int clamp(int f, int a, int b) +{ + return max(a, min(f, b)); +} +inline __device__ __host__ uint clamp(uint f, uint a, uint b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ float2 clamp(float2 v, float a, float b) +{ + return make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b) +{ + return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ 
__host__ float3 clamp(float3 v, float a, float b) +{ + return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) +{ + return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ float4 clamp(float4 v, float a, float b) +{ + return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b) +{ + return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +inline __device__ __host__ int2 clamp(int2 v, int a, int b) +{ + return make_int2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b) +{ + return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ int3 clamp(int3 v, int a, int b) +{ + return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b) +{ + return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ int4 clamp(int4 v, int a, int b) +{ + return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b) +{ + return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b) +{ + return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b)); +} +inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b) +{ + return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} +inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b) +{ + return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} +inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b) +{ + return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} +inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b) +{ + return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} +inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b) +{ + return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// dot product +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float dot(float2 a, float2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ float dot(float3 a, float3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ float dot(float4 a, float4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +inline __host__ __device__ int dot(int2 a, int2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ int dot(int3 a, int3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ int dot(int4 a, int4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +inline __host__ __device__ uint dot(uint2 a, uint2 b) +{ + return a.x * b.x + a.y * b.y; +} +inline __host__ __device__ uint dot(uint3 a, uint3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} +inline __host__ __device__ uint dot(uint4 a, uint4 b) +{ + return a.x * 
b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +//////////////////////////////////////////////////////////////////////////////// +// length +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float length(float2 v) +{ + return sqrtf(dot(v, v)); +} +inline __host__ __device__ float length(float3 v) +{ + return sqrtf(dot(v, v)); +} +inline __host__ __device__ float length(float4 v) +{ + return sqrtf(dot(v, v)); +} + +//////////////////////////////////////////////////////////////////////////////// +// normalize +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 normalize(float2 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} +inline __host__ __device__ float3 normalize(float3 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} +inline __host__ __device__ float4 normalize(float4 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +//////////////////////////////////////////////////////////////////////////////// +// floor +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 floorf(float2 v) +{ + return make_float2(floorf(v.x), floorf(v.y)); +} +inline __host__ __device__ float3 floorf(float3 v) +{ + return make_float3(floorf(v.x), floorf(v.y), floorf(v.z)); +} +inline __host__ __device__ float4 floorf(float4 v) +{ + return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// frac - returns the fractional portion of a scalar or each vector component +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float fracf(float v) +{ + return v - floorf(v); +} +inline __host__ __device__ float2 fracf(float2 v) +{ + return make_float2(fracf(v.x), fracf(v.y)); +} +inline __host__ __device__ float3 fracf(float3 v) +{ + return make_float3(fracf(v.x), fracf(v.y), fracf(v.z)); +} +inline __host__ __device__ float4 fracf(float4 v) +{ + return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// fmod +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fmodf(float2 a, float2 b) +{ + return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y)); +} +inline __host__ __device__ float3 fmodf(float3 a, float3 b) +{ + return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z)); +} +inline __host__ __device__ float4 fmodf(float4 a, float4 b) +{ + return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// absolute value +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float2 fabs(float2 v) +{ + return make_float2(fabs(v.x), fabs(v.y)); +} +inline __host__ __device__ float3 fabs(float3 v) +{ + return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); +} +inline __host__ __device__ float4 fabs(float4 v) +{ + return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); +} + +inline __host__ __device__ int2 abs(int2 v) +{ + return make_int2(abs(v.x), abs(v.y)); +} +inline __host__ __device__ int3 abs(int3 v) +{ + return make_int3(abs(v.x), abs(v.y), abs(v.z)); +} +inline __host__ 
__device__ int4 abs(int4 v) +{ + return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w)); +} + +//////////////////////////////////////////////////////////////////////////////// +// reflect +// - returns reflection of incident ray I around surface normal N +// - N should be normalized, reflected vector's length is equal to length of I +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float3 reflect(float3 i, float3 n) +{ + return i - 2.0f * n * dot(n,i); +} + +//////////////////////////////////////////////////////////////////////////////// +// cross product +//////////////////////////////////////////////////////////////////////////////// + +inline __host__ __device__ float3 cross(float3 a, float3 b) +{ + return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); +} + +//////////////////////////////////////////////////////////////////////////////// +// smoothstep +// - returns 0 if x < a +// - returns 1 if x > b +// - otherwise returns smooth interpolation between 0 and 1 based on x +//////////////////////////////////////////////////////////////////////////////// + +inline __device__ __host__ float smoothstep(float a, float b, float x) +{ + float y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(3.0f - (2.0f*y))); +} +inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x) +{ + float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y))); +} +inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) +{ + float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y))); +} +inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x) +{ + float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y))); +} + +#endif diff --git a/dva/mvp/extensions/utils/makefile b/dva/mvp/extensions/utils/makefile new file mode 100644 index 0000000000000000000000000000000000000000..4a1f97a7a9c9320562641ad94b7ada28ef5c2777 --- /dev/null +++ b/dva/mvp/extensions/utils/makefile @@ -0,0 +1,2 @@ +all: + python setup.py build_ext --inplace diff --git a/dva/mvp/extensions/utils/setup.py b/dva/mvp/extensions/utils/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..317233d0a08255a7c91d33266f3c6c936df0da9d --- /dev/null +++ b/dva/mvp/extensions/utils/setup.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup + +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + +if __name__ == "__main__": + import torch + setup( + name="utils", + ext_modules=[ + CUDAExtension( + "utilslib", + sources=["utils.cpp", "utils_kernel.cu"], + extra_compile_args={ + "nvcc": [ + "-arch=sm_70", + "-std=c++14", + "-lineinfo", + ] + } + ) + ], + cmdclass={"build_ext": BuildExtension} + ) diff --git a/dva/mvp/extensions/utils/utils.cpp b/dva/mvp/extensions/utils/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..38230a513b97cb8ca67bd400093327529601f1c8 --- /dev/null +++ b/dva/mvp/extensions/utils/utils.cpp @@ -0,0 +1,137 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include <torch/extension.h> +#include <cuda_runtime.h> + +#include <vector> + +void compute_raydirs_forward_cuda( + int N, int H, int W, + float * viewposim, + float * viewrotim, + float * focalim, + float * princptim, + float * pixelcoordsim, + float volradius, + float * raypos, + float * raydir, + float * tminmax, + cudaStream_t stream); + +void compute_raydirs_backward_cuda( + int N, int H, int W, + float * viewposim, + float * viewrotim, + float * focalim, + float * princptim, + float * pixelcoordsim, + float volradius, + float * raypos, + float * raydir, + float * tminmax, + float * grad_viewposim, + float * grad_viewrotim, + float * grad_focalim, + float * grad_princptim, + cudaStream_t stream); + +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x)) +
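+// Host-side wrappers below launch the CUDA kernels declared above. Tensor shapes, as allocated by the Python wrapper in utils.py: +// viewpos [N, 3], viewrot [N, 3, 3], focal [N, 2], princpt [N, 2], optional pixelcoords [N, H, W, 2]; +// outputs raypos / raydir [N, H, W, 3] and tminmax [N, H, W, 2]. +// (These are the shapes utils.py passes; the wrappers themselves only check device and contiguity.)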
+std::vector<torch::Tensor> compute_raydirs_forward( + torch::Tensor viewposim, + torch::Tensor viewrotim, + torch::Tensor focalim, + torch::Tensor princptim, + torch::optional<torch::Tensor> pixelcoordsim, + int W, int H, + float volradius, + torch::Tensor rayposim, + torch::Tensor raydirim, + torch::Tensor tminmaxim) { + CHECK_INPUT(viewposim); + CHECK_INPUT(viewrotim); + CHECK_INPUT(focalim); + CHECK_INPUT(princptim); + if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); } + CHECK_INPUT(rayposim); + CHECK_INPUT(raydirim); + CHECK_INPUT(tminmaxim); + + int N = viewposim.size(0); + assert(!pixelcoordsim || (pixelcoordsim->size(1) == H && pixelcoordsim->size(2) == W)); + + compute_raydirs_forward_cuda(N, H, W, + reinterpret_cast<float *>(viewposim.data_ptr()), + reinterpret_cast<float *>(viewrotim.data_ptr()), + reinterpret_cast<float *>(focalim.data_ptr()), + reinterpret_cast<float *>(princptim.data_ptr()), + pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr, + volradius, + reinterpret_cast<float *>(rayposim.data_ptr()), + reinterpret_cast<float *>(raydirim.data_ptr()), + reinterpret_cast<float *>(tminmaxim.data_ptr()), + 0); + + return {}; +} + +std::vector<torch::Tensor> compute_raydirs_backward( + torch::Tensor viewposim, + torch::Tensor viewrotim, + torch::Tensor focalim, + torch::Tensor princptim, + torch::optional<torch::Tensor> pixelcoordsim, + int W, int H, + float volradius, + torch::Tensor rayposim, + torch::Tensor raydirim, + torch::Tensor tminmaxim, + torch::Tensor grad_viewpos, + torch::Tensor grad_viewrot, + torch::Tensor grad_focal, + torch::Tensor grad_princpt) { + CHECK_INPUT(viewposim); + CHECK_INPUT(viewrotim); + CHECK_INPUT(focalim); + CHECK_INPUT(princptim); + if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); } + CHECK_INPUT(rayposim); + CHECK_INPUT(raydirim); + CHECK_INPUT(tminmaxim); + CHECK_INPUT(grad_viewpos); + CHECK_INPUT(grad_viewrot); + CHECK_INPUT(grad_focal); + CHECK_INPUT(grad_princpt); + + int N = viewposim.size(0); + assert(!pixelcoordsim || (pixelcoordsim->size(1) == H && pixelcoordsim->size(2) == W)); + + compute_raydirs_backward_cuda(N, H, W, + reinterpret_cast<float *>(viewposim.data_ptr()), + reinterpret_cast<float *>(viewrotim.data_ptr()), + reinterpret_cast<float *>(focalim.data_ptr()), + reinterpret_cast<float *>(princptim.data_ptr()), + pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr, + volradius, + reinterpret_cast<float *>(rayposim.data_ptr()), + reinterpret_cast<float *>(raydirim.data_ptr()), + reinterpret_cast<float *>(tminmaxim.data_ptr()), + reinterpret_cast<float *>(grad_viewpos.data_ptr()), + reinterpret_cast<float *>(grad_viewrot.data_ptr()), + reinterpret_cast<float *>(grad_focal.data_ptr()), + reinterpret_cast<float *>(grad_princpt.data_ptr()), + 0); + + return {}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compute_raydirs_forward", &compute_raydirs_forward, "raydirs forward (CUDA)"); + m.def("compute_raydirs_backward", &compute_raydirs_backward, "raydirs backward (CUDA)"); +} diff --git a/dva/mvp/extensions/utils/utils.py b/dva/mvp/extensions/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8f948fadd7764941a8c897f398af4df85be5cc --- /dev/null +++ b/dva/mvp/extensions/utils/utils.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +import torch.nn.functional as F + +try: + from . import utilslib +except ImportError: + import utilslib + +# Thin autograd wrapper around the compiled utilslib extension; backward propagates no gradients to its inputs. +class ComputeRaydirs(Function): + @staticmethod + def forward(self, viewpos, viewrot, focal, princpt, pixelcoords, volradius): + for tensor in [viewpos, viewrot, focal, princpt, pixelcoords]: + assert tensor.is_contiguous() + + N = viewpos.size(0) + if isinstance(pixelcoords, tuple): + W, H = pixelcoords + pixelcoords = None + else: + H = pixelcoords.size(1) + W = pixelcoords.size(2) + + raypos = torch.empty((N, H, W, 3), device=viewpos.device) + raydirs = torch.empty((N, H, W, 3), device=viewpos.device) + tminmax = torch.empty((N, H, W, 2), device=viewpos.device) + utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt, + pixelcoords, W, H, volradius, raypos, raydirs, tminmax) + + return raypos, raydirs, tminmax + + @staticmethod + def backward(self, grad_raydirs, grad_tminmax): + return None, None, None, None, None, None + +def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius): + raypos, raydirs, tminmax = ComputeRaydirs.apply(viewpos, viewrot, focal, princpt, pixelcoords, volradius) + return raypos, raydirs, tminmax + +# Rodrigues' formula: axis-angle vectors (N, 3) -> rotation matrices (N, 3, 3). +class Rodrigues(nn.Module): + def __init__(self): + super(Rodrigues, self).__init__() + + def forward(self, rvec): + theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) + rvec = rvec / theta[:, None] + costh = torch.cos(theta) + sinth = torch.sin(theta) + return torch.stack(( + rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh, + rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth, + rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth, + + rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth, + rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh, + rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth, + + rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth, + rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth, + rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3) + +def gradcheck(): + N = 2 + H = 64 + W = 64 + k3 = 4 + K = k3*k3*k3 + + M = 32 + volradius = 1. + + # generate random inputs + torch.manual_seed(1113) + + rodrigues = Rodrigues() + + _viewpos = torch.tensor([[-0.0, 0.0, -4.] 
for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1 + viewrvec = torch.randn(N, 3, device="cuda") * 0.01 + _viewrot = rodrigues(viewrvec) + + _focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda") + _princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda") + pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float()) + _pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1) + + _viewpos = _viewpos.contiguous().detach().clone() + _viewpos.requires_grad = True + _viewrot = _viewrot.contiguous().detach().clone() + _viewrot.requires_grad = True + _focal = _focal.contiguous().detach().clone() + _focal.requires_grad = True + _princpt = _princpt.contiguous().detach().clone() + _princpt.requires_grad = True + _pixelcoords = _pixelcoords.contiguous().detach().clone() + _pixelcoords.requires_grad = True + + max_len = 6.0 + _stepsize = max_len / 15.5 + + params = [_viewpos, _viewrot, _focal, _princpt] + paramnames = ["viewpos", "viewrot", "focal", "princpt"] + + ########################### run pytorch version ########################### + + viewpos = _viewpos + viewrot = _viewrot + focal = _focal + princpt = _princpt + pixelcoords = _pixelcoords + + raypos = viewpos[:, None, None, :].repeat(1, H, W, 1) + + raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :] + raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1) + raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2) + raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True)) + + t1 = (-1. - viewpos[:, None, None, :]) / raydir + t2 = ( 1. - viewpos[:, None, None, :]) / raydir + tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]), + torch.max(torch.min(t1[..., 1], t2[..., 1]), + torch.min(t1[..., 2], t2[..., 2]))).clamp(min=0.) 
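    # (editor note) tmin above and tmax below implement the standard slab test:
    # each ray is intersected with the axis-aligned cube [-1, 1]^3 (the bounding
    # volume in normalized coordinates); the entry distance is the max of the
    # per-axis minima and the exit distance is the min of the per-axis maxima,
    # with tmin clamped to zero so marching never starts behind the camera.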
+ tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]), + torch.min(torch.max(t1[..., 1], t2[..., 1]), + torch.max(t1[..., 2], t2[..., 2]))) + + tminmax = torch.stack([tmin, tmax], dim=-1) + + sample0 = raydir + + torch.cuda.synchronize() + time1 = time.time() + + sample0.backward(torch.ones_like(sample0)) + + torch.cuda.synchronize() + time2 = time.time() + + grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params] + + for p in params: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + ############################## run cuda version ########################### + + viewpos = _viewpos + viewrot = _viewrot + focal = _focal + princpt = _princpt + pixelcoords = _pixelcoords + + niter = 1 + + for p in params: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + t0 = time.time() + torch.cuda.synchronize() + + sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1] + + t1 = time.time() + torch.cuda.synchronize() + + print("-----------------------------------------------------------------") + print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda")) + ind = torch.argmax(torch.abs(sample0 - sample1)) + print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format( + "fwd", + torch.max(torch.abs(sample0 - sample1)).item(), + (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(), + ind.item(), + sample0.view(-1)[ind].item(), + sample1.view(-1)[ind].item())) + + sample1.backward(torch.ones_like(sample1), retain_graph=True) + + torch.cuda.synchronize() + t2 = time.time() + + + print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter)) + grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params] + + ############# compare results ############# + + for p, g0, g1 in zip(paramnames, grads0, grads1): + ind = torch.argmax(torch.abs(g0 - g1)) + print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format( + p, + torch.max(torch.abs(g0 - g1)).item(), + (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(), + ind.item(), + g0.view(-1)[ind].item(), + g1.view(-1)[ind].item())) + +if __name__ == "__main__": + gradcheck() diff --git a/dva/mvp/extensions/utils/utils_kernel.cu b/dva/mvp/extensions/utils/utils_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..37f4a030066394fcfe743cdb3d17d914101ffcc0 --- /dev/null +++ b/dva/mvp/extensions/utils/utils_kernel.cu @@ -0,0 +1,174 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include + +#include "helper_math.h" + +__global__ void compute_raydirs_forward_kernel( + int N, int H, int W, + float3 * viewposim, + float3 * viewrotim, + float2 * focalim, + float2 * princptim, + float2 * pixelcoordsim, + float volradius, + float3 * rayposim, + float3 * raydirim, + float2 * tminmaxim + ) { + bool validthread = false; + int w, h, n; + w = blockIdx.x * blockDim.x + threadIdx.x; + h = (blockIdx.y * blockDim.y + threadIdx.y)%H; + n = (blockIdx.y * blockDim.y + threadIdx.y)/H; + validthread = (w < W) && (h < H) && (n>>( + N, H, W, + reinterpret_cast(viewposim), + reinterpret_cast(viewrotim), + reinterpret_cast(focalim), + reinterpret_cast(princptim), + reinterpret_cast(pixelcoordsim), + volradius, + reinterpret_cast(rayposim), + reinterpret_cast(raydirim), + reinterpret_cast(tminmaxim)); +} + +void compute_raydirs_backward_cuda( + int N, int H, int W, + float * viewposim, + float * viewrotim, + float * focalim, + float * princptim, + float * pixelcoordsim, + float volradius, + float * rayposim, + float * raydirim, + float * tminmaxim, + float * grad_viewposim, + float * grad_viewrotim, + float * grad_focalim, + float * grad_princptim, + cudaStream_t stream) { + int blocksizex = 16; + int blocksizey = 16; + dim3 blocksize(blocksizex, blocksizey); + dim3 gridsize; + gridsize = dim3( + (W + blocksize.x - 1) / blocksize.x, + (N*H + blocksize.y - 1) / blocksize.y); + + auto fn = compute_raydirs_backward_kernel; + fn<<>>( + N, H, W, + reinterpret_cast(viewposim), + reinterpret_cast(viewrotim), + reinterpret_cast(focalim), + reinterpret_cast(princptim), + reinterpret_cast(pixelcoordsim), + volradius, + reinterpret_cast(rayposim), + reinterpret_cast(raydirim), + reinterpret_cast(tminmaxim), + reinterpret_cast(grad_viewposim), + reinterpret_cast(grad_viewrotim), + reinterpret_cast(grad_focalim), + reinterpret_cast(grad_princptim)); +} diff --git a/dva/mvp/models/bg/lap.py b/dva/mvp/models/bg/lap.py new file mode 100644 index 0000000000000000000000000000000000000000..f553b2d042068ac0a06137ec50d4953aa74fc646 --- /dev/null +++ b/dva/mvp/models/bg/lap.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
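# --- Editor's sketch (not part of the original patch) -----------------------
# Minimal usage sketch for the compute_raydirs wrapper defined in
# dva/mvp/extensions/utils/utils.py above. Shapes follow the gradcheck()
# routine; the camera values are illustrative placeholders, and the snippet
# assumes the utilslib CUDA extension has been built and that dva/mvp is on
# the Python path (the same convention as the raymarcher imports).
import torch
from extensions.utils.utils import compute_raydirs

N, H, W = 2, 64, 64
viewpos = torch.tensor([[0., 0., -4.]] * N, device="cuda")           # camera centers, [N, 3]
viewrot = torch.eye(3, device="cuda").expand(N, 3, 3).contiguous()   # camera rotations, [N, 3, 3]
focal = torch.full((N, 2), W * 4.0, device="cuda")                   # focal lengths in pixels, [N, 2]
princpt = torch.tensor([[W * 0.5, H * 0.5]] * N, device="cuda")      # principal points, [N, 2]
py, px = torch.meshgrid(torch.arange(H, device="cuda").float(),
                        torch.arange(W, device="cuda").float())
pixelcoords = torch.stack([px, py], dim=-1)[None].expand(N, H, W, 2).contiguous()

# raypos/raydir come back as [N, H, W, 3]; tminmax is [N, H, W, 2] holding the
# entry/exit distances of each ray against the normalized [-1, 1]^3 volume.
raypos, raydir, tminmax = compute_raydirs(viewpos, viewrot, focal, princpt,
                                          pixelcoords, volradius=1.0)
# --- end of sketch -----------------------------------------------------------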
+from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models.utils + +class ImageMod(nn.Module): + def __init__(self, width, height, depth, buf=False): + super(ImageMod, self).__init__() + + if buf: + self.register_buffer("image", torch.randn(1, 3, depth, height, width) * 0.001, persistent=False) + else: + self.image = nn.Parameter(torch.randn(1, 3, depth, height, width) * 0.001) + + def forward(self, samplecoords): + image = self.image.expand(samplecoords.size(0), -1, -1, -1, -1) + return F.grid_sample(image, samplecoords, align_corners=True) + +class LapImage(nn.Module): + def __init__(self, width, height, depth, levels, startlevel=0, buftop=False, align_corners=True): + super(LapImage, self).__init__() + + self.width : int = int(width) + self.height : int = int(height) + self.levels = levels + self.startlevel = startlevel + self.align_corners = align_corners + + self.pyr = nn.ModuleList( + [ImageMod(self.width // 2 ** i, self.height // 2 ** i, depth) + for i in list(range(startlevel, levels - 1))[::-1]] + + ([ImageMod(self.width, self.height, depth, buf=True)] if buftop else [])) + self.pyr0 = ImageMod(self.width // 2 ** (levels - 1), self.height // 2 ** (levels - 1), depth) + + def forward(self, samplecoords): + image = self.pyr0(samplecoords) + + for i, layer in enumerate(self.pyr): + image = image + layer(samplecoords) + + return image + +class BGModel(nn.Module): + def __init__(self, width, height, allcameras, bgdict=True, trainstart=0, + levels=5, startlevel=0, buftop=False, align_corners=True): + super(BGModel, self).__init__() + + self.allcameras = allcameras + self.trainstart = trainstart + + if trainstart > -1: + self.lap = LapImage(width, height, len(allcameras), levels=levels, + startlevel=startlevel, buftop=buftop, + align_corners=align_corners) + + def forward( + self, + bg : Optional[torch.Tensor]=None, + camindex : Optional[torch.Tensor]=None, + raypos : Optional[torch.Tensor]=None, + rayposend : Optional[torch.Tensor]=None, + raydir : Optional[torch.Tensor]=None, + samplecoords : Optional[torch.Tensor]=None, + trainiter : float=-1): + if self.trainstart > -1 and trainiter >= self.trainstart and camindex is not None: + assert samplecoords is not None + assert camindex is not None + + samplecoordscam = torch.cat([ + samplecoords[:, None, :, :, :], # [B, 1, H, W, 2] + ((camindex[:, None, None, None, None] * 2.) / (len(self.allcameras) - 1.) - 1.) + .expand(-1, -1, samplecoords.size(1), samplecoords.size(2), -1)], + dim=-1) # [B, 1, H, W, 3] + lap = self.lap(samplecoordscam)[:, :, 0, :, :] + else: + lap = None + + if lap is None: + return None + else: + return F.softplus(lap) diff --git a/dva/mvp/models/bg/mlp.py b/dva/mvp/models/bg/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..7e92410af9c2763a334776c5b76c373c5b16f570 --- /dev/null +++ b/dva/mvp/models/bg/mlp.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
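# --- Editor's sketch (not part of the original patch) -----------------------
# Usage sketch for the Laplacian-pyramid background model above: every LapImage
# level is an independently learned image volume and the output is simply the
# sum of all levels sampled at the same normalized coordinates via grid_sample.
# The sizes below are illustrative.
import torch
from models.bg.lap import LapImage  # assumed import path, as in the other modules

B, H, W = 2, 32, 32
ncams = 4                                    # the depth axis indexes the cameras
lap = LapImage(width=256, height=256, depth=ncams, levels=5)

# normalized sample coordinates in [-1, 1]: screen-space (x, y) plus a camera
# index remapped to [-1, 1] along the depth axis, shape [B, 1, H, W, 3]
xy = torch.rand(B, 1, H, W, 2) * 2. - 1.
cam = torch.full((B, 1, H, W, 1), -1.)       # -1 selects the first camera slice
samplecoords = torch.cat([xy, cam], dim=-1)

bg = lap(samplecoords)                       # [B, 3, 1, H, W] background sample
# --- end of sketch -----------------------------------------------------------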
+import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.utils import BufferDict, Conv2dELR + +class BGModel(nn.Module): + def __init__(self, width, height, allcameras, bgdict=True, demod=True, trainstart=0): + super(BGModel, self).__init__() + + self.allcameras = allcameras + self.trainstart = trainstart + + if bgdict: + self.bg = BufferDict({k: torch.ones(3, height, width) for k in allcameras}) + else: + self.bg = None + + if trainstart > -1: + self.mlp1 = nn.Sequential( + Conv2dELR(60+24, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None)) + + self.mlp2 = nn.Sequential( + Conv2dELR(60+24+256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), + Conv2dELR( 256, 3, 1, 1, 0, demod=False)) + + def forward(self, bg=None, camindex=None, raypos=None, rayposend=None, + raydir=None, samplecoords=None, trainiter=-1, **kwargs): + if self.trainstart > -1 and trainiter >= self.trainstart:# and camindex is not None: + # generate position encoding + posenc = torch.cat([ + torch.sin(2 ** i * np.pi * rayposend[:, :, :, :]) + for i in range(10)] + [ + torch.cos(2 ** i * np.pi * rayposend[:, :, :, :]) + for i in range(10)], dim=-1).permute(0, 3, 1, 2) + + direnc = torch.cat([ + torch.sin(2 ** i * np.pi * raydir[:, :, :, :]) + for i in range(4)] + [ + torch.cos(2 ** i * np.pi * raydir[:, :, :, :]) + for i in range(4)], dim=-1).permute(0, 3, 1, 2) + + decout = torch.cat([posenc, direnc], dim=1) + decout = self.mlp1(decout) + + decout = torch.cat([posenc, direnc, decout], dim=1) + decout = self.mlp2(decout) + else: + decout = None + + if bg is None and self.bg is not None and camindex is not None: + bg = torch.stack([self.bg[self.allcameras[camindex[i].item()]] for i in range(camindex.size(0))], dim=0) + else: + bg = None + + if bg is not None and samplecoords is not None: + if samplecoords.size()[1:3] != bg.size()[2:4]: + bg = F.grid_sample(bg, samplecoords, align_corners=False) + + if decout is not None: + if bg is not None: + return F.softplus(bg + decout) + else: + return F.softplus(decout) + else: + if bg is not None: + return F.softplus(bg) + else: + return None diff --git a/dva/mvp/models/colorcals/colorcal.py b/dva/mvp/models/colorcals/colorcal.py new file mode 100644 index 0000000000000000000000000000000000000000..73b58b819148eb1c9c78fe1aaeb4a0632af0932d --- /dev/null +++ b/dva/mvp/models/colorcals/colorcal.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
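# --- Editor's sketch (not part of the original patch) -----------------------
# The MLP background model above encodes ray endpoints with 10 frequency bands
# and ray directions with 4 bands, NeRF-style sin/cos of 2^i * pi * x. A
# standalone version of that encoding for reference:
import numpy as np
import torch
import torch.nn.functional as F

def positional_encoding(x: torch.Tensor, nbands: int) -> torch.Tensor:
    """x: [..., C] -> [..., 2 * nbands * C] of concatenated sin/cos features."""
    feats = [torch.sin(2 ** i * np.pi * x) for i in range(nbands)] + \
            [torch.cos(2 ** i * np.pi * x) for i in range(nbands)]
    return torch.cat(feats, dim=-1)

rayposend = torch.rand(2, 8, 8, 3) * 2. - 1.              # ray termination points
posenc = positional_encoding(rayposend, nbands=10)        # [B, H, W, 60]
raydir = F.normalize(torch.randn(2, 8, 8, 3), dim=-1)     # unit ray directions
direnc = positional_encoding(raydir, nbands=4)            # [B, H, W, 24]
# the model permutes both to NCHW and feeds the 60 + 24 = 84 channels to mlp1
# --- end of sketch -----------------------------------------------------------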
+import torch +import torch.nn as nn +import torch.nn.functional as F + +class Colorcal(nn.Module): + def __init__(self, allcameras): + super(Colorcal, self).__init__() + + self.allcameras = allcameras + + self.weight = nn.Parameter( + torch.ones(len(self.allcameras), 3)) + self.bias = nn.Parameter( + torch.zeros(len(self.allcameras), 3)) + + def forward(self, image, camindex): + # collect weights + weight = self.weight[camindex] + bias = self.bias[camindex] + + # reshape + b = image.size(0) + groups = b * 3 + image = image.view(1, -1, image.size(2), image.size(3)) + weight = weight.view(-1, 1, 1, 1) + bias = bias.view(-1) + + # conv + result = F.conv2d(image, weight, bias, groups=groups) + + # unshape + result = result.view(b, 3, image.size(2), image.size(3)) + return result + + def parameters(self): + for p in super(Colorcal, self).parameters(): + if p.requires_grad: + yield p diff --git a/dva/mvp/models/decoders/mvp.py b/dva/mvp/models/decoders/mvp.py new file mode 100644 index 0000000000000000000000000000000000000000..c772f53c0990d4c19156df79b34daa3d4367bc57 --- /dev/null +++ b/dva/mvp/models/decoders/mvp.py @@ -0,0 +1,641 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" MVP decoder """ +import math +from typing import Optional, Dict, List + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models.utils +from models.utils import LinearELR, ConvTranspose2dELR, ConvTranspose3dELR + +@torch.jit.script +def compute_postex(geo, idxim, barim, volradius : float): + # compute 3d coordinates of each texel in uv map + return ( + barim[None, :, :, 0, None] * geo[:, idxim[:, :, 0], :] + + barim[None, :, :, 1, None] * geo[:, idxim[:, :, 1], :] + + barim[None, :, :, 2, None] * geo[:, idxim[:, :, 2], :] + ).permute(0, 3, 1, 2) / volradius + +@torch.jit.script +def compute_tbn(v0, v1, v2, vt0, vt1, vt2): + v01 = v1 - v0 + v02 = v2 - v0 + vt01 = vt1 - vt0 + vt02 = vt2 - vt0 + f = 1. 
/ (vt01[None, :, :, 0] * vt02[None, :, :, 1] - vt01[None, :, :, 1] * vt02[None, :, :, 0]) + tangent = f[:, :, :, None] * torch.stack([ + v01[:, :, :, 0] * vt02[None, :, :, 1] - v02[:, :, :, 0] * vt01[None, :, :, 1], + v01[:, :, :, 1] * vt02[None, :, :, 1] - v02[:, :, :, 1] * vt01[None, :, :, 1], + v01[:, :, :, 2] * vt02[None, :, :, 1] - v02[:, :, :, 2] * vt01[None, :, :, 1]], dim=-1) + tangent = F.normalize(tangent, dim=-1) + normal = torch.cross(v01, v02, dim=3) + normal = F.normalize(normal, dim=-1) + bitangent = torch.cross(tangent, normal, dim=3) + bitangent = F.normalize(bitangent, dim=-1) + + # create matrix + primrotmesh = torch.stack((tangent, bitangent, normal), dim=-1) + + return primrotmesh + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + return x.view(self.shape) + +# RGBA decoder +class SlabContentDecoder(nn.Module): + def __init__(self, nprims, primsize, inch, outch, chstart=256, hstart=4, + texwarp=False, elr=True, norm=None, mod=False, ub=True, upconv=None, + penultch=None, use3dconv=False, reduced3dch=False): + super(SlabContentDecoder, self).__init__() + + assert not texwarp + assert upconv == None + + self.nprims = nprims + self.primsize = primsize + + self.nprimy = int(math.sqrt(nprims)) + self.nprimx = nprims // self.nprimy + assert nprims == self.nprimx * self.nprimy + + self.slabw = self.nprimx * primsize[0] + self.slabh = self.nprimy * primsize[1] + self.slabd = primsize[2] + + nlayers = int(math.log2(min(self.slabw, self.slabh))) - int(math.log2(hstart)) + nlayers3d = int(math.log2(self.slabd)) + nlayers2d = nlayers - nlayers3d + + lastch = chstart + dims = (1, hstart, hstart * self.nprimx // self.nprimy) + + layers = [] + layers.append(LinearELR(inch, chstart*dims[1]*dims[2], act=nn.LeakyReLU(0.2))) + layers.append(Reshape(-1, chstart, dims[1], dims[2])) + + for i in range(nlayers): + nextch = lastch if i % 2 == 0 else lastch // 2 + + if use3dconv and reduced3dch and i >= nlayers2d: + nextch //= 2 + + if i == nlayers - 2 and penultch is not None: + nextch = penultch + + if use3dconv and i >= nlayers2d: + if i == nlayers2d: + layers.append(Reshape(-1, lastch, 1, dims[1], dims[2])) + layers.append(ConvTranspose3dELR( + lastch, + (outch if i == nlayers - 1 else nextch), + 4, 2, 1, + ub=(dims[0]*2, dims[1]*2, dims[2]*2) if ub else None, + norm=None if i == nlayers - 1 else norm, + act=None if i == nlayers - 1 else nn.LeakyReLU(0.2) + )) + else: + layers.append(ConvTranspose2dELR( + lastch, + (outch * primsize[2] if i == nlayers - 1 else nextch), + 4, 2, 1, + ub=(dims[1]*2, dims[2]*2) if ub else None, + norm=None if i == nlayers - 1 else norm, + act=None if i == nlayers - 1 else nn.LeakyReLU(0.2) + )) + + lastch = nextch + dims = (dims[0] * (2 if use3dconv and i >= nlayers2d else 1), dims[1] * 2, dims[2] * 2) + + self.mod = nn.Sequential(*layers) + + def forward(self, enc, renderoptions : Dict[str, str], trainiter : Optional[int]=None): + x = self.mod(enc) + + algo = renderoptions.get("algo") + chlast = renderoptions.get("chlast") + + if chlast is not None and bool(chlast): + # reorder channels last + if len(x.size()) == 5: + outch = x.size(1) + x = x.view(x.size(0), outch, self.primsize[2], self.nprimy, self.primsize[1], self.nprimx, self.primsize[0]) + x = x.permute(0, 3, 5, 2, 4, 6, 1) + x = x.reshape(x.size(0), self.nprims, self.primsize[2], self.primsize[1], self.primsize[0], outch) + else: + outch = x.size(1) // self.primsize[2] + x = x.view(x.size(0), self.primsize[2], 
outch, self.nprimy, self.primsize[1], self.nprimx, self.primsize[0]) + x = x.permute(0, 3, 5, 1, 4, 6, 2) + x = x.reshape(x.size(0), self.nprims, self.primsize[2], self.primsize[1], self.primsize[0], outch) + else: + if len(x.size()) == 5: + outch = x.size(1) + x = x.view(x.size(0), outch, self.primsize[2], self.nprimy, self.primsize[1], self.nprimx, self.primsize[0]) + x = x.permute(0, 3, 5, 1, 2, 4, 6) + x = x.reshape(x.size(0), self.nprims, outch, self.primsize[2], self.primsize[1], self.primsize[0]) + else: + outch = x.size(1) // self.primsize[2] + x = x.view(x.size(0), self.primsize[2], outch, self.nprimy, self.primsize[1], self.nprimx, self.primsize[0]) + x = x.permute(0, 3, 5, 2, 1, 4, 6) + x = x.reshape(x.size(0), self.nprims, outch, self.primsize[2], self.primsize[1], self.primsize[0]) + + return x + +def get_dec(dectype, **kwargs): + if dectype == "slab2d": + return SlabContentDecoder(**kwargs, use3dconv=False) + elif dectype == "slab2d3d": + return SlabContentDecoder(**kwargs, use3dconv=True) + elif dectype == "slab2d3dv2": + return SlabContentDecoder(**kwargs, use3dconv=True, reduced3dch=True) + else: + raise + +# motion model for the delta from mesh-based position/orientation +class DeconvMotionModel(nn.Module): + def __init__(self, nprims, inch, outch, chstart=1024, + norm=None, mod=False, elr=True): + super(DeconvMotionModel, self).__init__() + + self.nprims = nprims + + self.nprimy = int(math.sqrt(nprims)) + self.nprimx = nprims // int(math.sqrt(nprims)) + assert nprims == self.nprimx * self.nprimy + + nlayers = int(math.log2(min(self.nprimx, self.nprimy))) + + ch0, ch1 = chstart, chstart // 2 + layers = [] + + layers.append(LinearELR(inch, ch0, norm=norm, act=nn.LeakyReLU(0.2))) + + layers.append(Reshape(-1, ch0, 1, self.nprimx // self.nprimy)) + dims = (1, 1, self.nprimx // self.nprimy) + + for i in range(nlayers): + layers.append(ConvTranspose2dELR( + ch0, + (outch if i == nlayers - 1 else ch1), + 4, 2, 1, + norm=None if i == nlayers - 1 else norm, + act=None if i == nlayers - 1 else nn.LeakyReLU(0.2) + )) + + if ch0 == ch1: + ch1 = ch0 // 2 + else: + ch0 = ch1 + + self.mod = nn.Sequential(*layers) + + def forward(self, encoding): + out = self.mod(encoding) + out = out.view(encoding.size(0), 9, -1).permute(0, 2, 1).contiguous() + + primposdelta = out[:, :, 0:3] + primrvecdelta = out[:, :, 3:6] + primscaledelta = out[:, :, 6:9] + return primposdelta, primrvecdelta, primscaledelta + +def get_motion(motiontype, **kwargs): + if motiontype == "deconv": + return DeconvMotionModel(**kwargs) + else: + raise + +class Decoder(nn.Module): + def __init__(self, + vt, + vertmean, + vertstd, + idxim, + tidxim, + barim, + volradius, + dectype="slab2d", + nprims=512, + primsize=(32, 32, 32), + chstart=256, + penultch=None, + condsize=0, + motiontype="deconv", + warptype=None, + warpprimsize=None, + sharedrgba=False, + norm=None, + mod=False, + elr=True, + scalemult=2., + nogeo=False, + notplateact=False, + postrainstart=-1, + alphatrainstart=-1, + renderoptions={}, + **kwargs): + """ + Parameters + ---------- + vt : numpy.array [V, 2] + mesh vertex texture coordinates + vertmean : numpy.array [V, 3] + mesh vertex position average (average over time) + vertstd : float + mesh vertex position standard deviation (over time) + idxim : torch.Tensor + texture map of triangle indices + tidxim : torch.Tensor + texture map of texture triangle indices + barim : torch.Tensor + texture map of barycentric coordinates + volradius : float + radius of bounding volume of scene + dectype : string + type of 
content decoder, options are "slab2d", "slab2d3d", "slab2d3dv2" + nprims : int + number of primitives + primsize : Tuple[int, int, int] + size of primitive dimensions + postrainstart : int + training iterations to start learning position, rotation, and + scaling (i.e., primitives stay frozen until this iteration number) + condsize : int + unused + motiontype : string + motion model, options are "linear" and "deconv" + warptype : string + warp model, options are "same" to use same architecture as content + or None + sharedrgba : bool + True to use 1 branch to output rgba, False to use 1 branch for rgb + and 1 branch for alpha + """ + super(Decoder, self).__init__() + + self.volradius = volradius + self.postrainstart = postrainstart + self.alphatrainstart = alphatrainstart + + self.nprims = nprims + self.primsize = primsize + + self.motiontype = motiontype + self.nogeo = nogeo + self.notplateact = notplateact + self.scalemult = scalemult + + self.enc = LinearELR(256 + condsize, 256) + + # vertex output + if not self.nogeo: + self.geobranch = LinearELR(256, vertmean.numel(), norm=None) + + # primitive motion delta decoder + self.motiondec = get_motion(motiontype, nprims=nprims, inch=256, outch=9, + norm=norm, mod=mod, elr=elr, **kwargs) + + # slab decoder (RGBA) + if sharedrgba: + self.rgbadec = get_dec(dectype, nprims=nprims, primsize=primsize, + inch=256+3, outch=4, norm=norm, mod=mod, elr=elr, + penultch=penultch, **kwargs) + + if renderoptions.get("half", False): + self.rgbadec = self.rgbadec.half() + + if renderoptions.get("chlastconv", False): + self.rgbadec = self.rgbadec.to(memory_format=torch.channels_last) + else: + self.rgbdec = get_dec(dectype, nprims=nprims, primsize=primsize, + inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod, + elr=elr, penultch=penultch, **kwargs) + self.alphadec = get_dec(dectype, nprims=nprims, primsize=primsize, + inch=256, outch=1, chstart=chstart, norm=norm, mod=mod, + elr=elr, penultch=penultch, **kwargs) + self.rgbadec = None + + if renderoptions.get("half", False): + self.rgbdec = self.rgbdec.half() + self.alphadec = self.alphadec.half() + + if renderoptions.get("chlastconv", False): + self.rgbdec = self.rgbdec.to(memory_format=torch.channels_last) + self.alphadec = self.alphadec.to(memory_format=torch.channels_last) + + # warp field decoder + if warptype is not None: + self.warpdec = get_dec(warptype, nprims=nprims, primsize=warpprimsize, + inch=256, outch=3, chstart=chstart, norm=norm, mod=mod, elr=elr, **kwargs) + else: + self.warpdec = None + + # vertex/triangle/mesh topology data + if vt is not None: + vt = torch.tensor(vt) if not isinstance(vt, torch.Tensor) else vt + self.register_buffer("vt", vt, persistent=False) + + if vertmean is not None: + self.register_buffer("vertmean", vertmean, persistent=False) + self.vertstd = vertstd + + idxim = torch.tensor(idxim) if not isinstance(idxim, torch.Tensor) else idxim + tidxim = torch.tensor(tidxim) if not isinstance(tidxim, torch.Tensor) else tidxim + barim = torch.tensor(barim) if not isinstance(barim, torch.Tensor) else barim + self.register_buffer("idxim", idxim.long(), persistent=False) + self.register_buffer("tidxim", tidxim.long(), persistent=False) + self.register_buffer("barim", barim, persistent=False) + + def forward(self, + encoding, + viewpos, + condinput : Optional[torch.Tensor]=None, + renderoptions : Optional[Dict[str, str]]=None, + trainiter : int=-1, + evaliter : Optional[torch.Tensor]=None, + losslist : Optional[List[str]]=None, + modelmatrix : Optional[torch.Tensor]=None): + 
""" + Parameters + ---------- + encoding : torch.Tensor [B, 256] + Encoding of current frame + viewpos : torch.Tensor [B, 3] + Viewing position of target camera view + condinput : torch.Tensor [B, ?] + Additional conditioning input (e.g., headpose) + renderoptions : dict + Options for rendering (e.g., rendering debug images) + trainiter : int, + Current training iteration + losslist : list, + List of losses to compute and return + + Returns + ------- + result : dict, + Contains predicted vertex positions, primitive contents and + locations, scaling, and orientation, and any losses. + """ + assert renderoptions is not None + assert losslist is not None + + if condinput is not None: + encoding = torch.cat([encoding, condinput], dim=1) + + encoding = self.enc(encoding) + + viewdirs = F.normalize(viewpos, dim=1) + + if int(math.sqrt(self.nprims)) ** 2 == self.nprims: + nprimsy = int(math.sqrt(self.nprims)) + else: + nprimsy = int(math.sqrt(self.nprims // 2)) + nprimsx = self.nprims // nprimsy + + assert nprimsx * nprimsy == self.nprims + + if not self.nogeo: + # decode mesh vertices + # geo [6, 7306, 3] + geo = self.geobranch(encoding) + geo = geo.view(encoding.size(0), -1, 3) + geo = geo * self.vertstd + self.vertmean + + # placement of primitives on mesh + uvheight, uvwidth = self.barim.size(0), self.barim.size(1) + stridey = uvheight // nprimsy + stridex = uvwidth // nprimsx + + # get subset of vertices and texture map coordinates to compute TBN matrix + v0 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 0], :] + v1 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 1], :] + v2 = geo[:, self.idxim[stridey//2::stridey, stridex//2::stridex, 2], :] + + vt0 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 0], :] + vt1 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 1], :] + vt2 = self.vt[self.tidxim[stridey//2::stridey, stridex//2::stridex, 2], :] + + # [6, 256, 3] + primposmesh = ( + self.barim[None, stridey//2::stridey, stridex//2::stridex, 0, None] * v0 + + self.barim[None, stridey//2::stridey, stridex//2::stridex, 1, None] * v1 + + self.barim[None, stridey//2::stridey, stridex//2::stridex, 2, None] * v2 + ).view(v0.size(0), self.nprims, 3) / self.volradius + + # compute TBN matrix + # primrotmesh [6, 16, 16, 3, 3] + primrotmesh = compute_tbn(v0, v1, v2, vt0, vt1, vt2) + + # decode motion deltas [6, 256, 3] + primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding) + if trainiter <= self.postrainstart: + primposdelta = primposdelta * 0. + primrvecdelta = primrvecdelta * 0. + primscaledelta = primscaledelta * 0. + + # compose mesh transform with deltas + primpos = primposmesh + primposdelta * 0.01 + primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.01) + primrot = torch.bmm( + primrotmesh.view(-1, 3, 3), + primrotdelta.view(-1, 3, 3)).view(encoding.size(0), self.nprims, 3, 3) + primscale = (self.scalemult * int(self.nprims ** (1. / 3))) * torch.exp(primscaledelta * 0.01) + + primtransf = None + else: + geo = None + + # decode motion deltas + primposdelta, primrvecdelta, primscaledelta = self.motiondec(encoding) + if trainiter <= self.postrainstart: + primposdelta = primposdelta * 0. + primrvecdelta = primrvecdelta * 0. + primscaledelta = primscaledelta * 0. + 1. + + primpos = primposdelta * 0.3 + primrotdelta = models.utils.axisangle_to_matrix(primrvecdelta * 0.3) + primrot = torch.exp(primrotdelta * 0.01) + primscale = (self.scalemult * int(self.nprims ** (1. 
/ 3))) * primscaledelta + + primtransf = None + + # options + algo = renderoptions.get("algo") + chlast = renderoptions.get("chlast") + half = renderoptions.get("half") + + if self.rgbadec is not None: + # shared rgb and alpha branch + scale = torch.tensor([25., 25., 25., 1.], device=encoding.device) + bias = torch.tensor([100., 100., 100., 0.], device=encoding.device) + if chlast is not None and bool(chlast): + scale = scale[None, None, None, None, None, :] + bias = bias[None, None, None, None, None, :] + else: + scale = scale[None, None, :, None, None, None] + bias = bias[None, None, :, None, None, None] + + templatein = torch.cat([encoding, viewdirs], dim=1) + if half is not None and bool(half): + templatein = templatein.half() + template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions) + template = bias + scale * template + if not self.notplateact: + template = F.relu(template) + if half is not None and bool(half): + template = template.float() + else: + templatein = torch.cat([encoding, viewdirs], dim=1) + if half is not None and bool(half): + templatein = templatein.half() + # primrgb [6, 256, 32, 32, 32, 3] -> [B, 256, primsize, 3] + primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions) + primrgb = primrgb * 25. + 100. + if not self.notplateact: + primrgb = F.relu(primrgb) + + templatein = encoding + if half is not None and bool(half): + templatein = templatein.half() + primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions) + if not self.notplateact: + primalpha = F.relu(primalpha) + + if trainiter <= self.alphatrainstart: + primalpha = primalpha * 0. + 1. + + if algo is not None and int(algo) == 4: + template = torch.cat([primrgb, primalpha], dim=-1) + elif chlast is not None and bool(chlast): + template = torch.cat([primrgb, primalpha], dim=-1) + else: + template = torch.cat([primrgb, primalpha], dim=2) + if half is not None and bool(half): + template = template.float() + + if self.warpdec is not None: + warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01 + warp = warp + torch.stack(torch.meshgrid( + torch.linspace(-1., 1., self.primsize[2], device=encoding.device), + torch.linspace(-1., 1., self.primsize[1], device=encoding.device), + torch.linspace(-1., 1., self.primsize[0], device=encoding.device))[::-1], + dim=-1 if chlast is not None and bool(chlast) else 0)[None, None, :, :, :, :] + else: + warp = None + + # debugging / visualization + viewaxes = renderoptions.get("viewaxes") + colorprims = renderoptions.get("colorprims") + viewslab = renderoptions.get("viewslab") + + # add axes to primitives + if viewaxes is not None and bool(viewaxes): + template[:, :, 3, template.size(3)//2:template.size(3)//2+1, template.size(4)//2:template.size(4)//2+1, :] = 2550. + template[:, :, 0, template.size(3)//2:template.size(3)//2+1, template.size(4)//2:template.size(4)//2+1, :] = 2550. + template[:, :, 3, template.size(3)//2:template.size(3)//2+1, :, template.size(5)//2:template.size(5)//2+1] = 2550. + template[:, :, 1, template.size(3)//2:template.size(3)//2+1, :, template.size(5)//2:template.size(5)//2+1] = 2550. + template[:, :, 3, :, template.size(4)//2:template.size(4)//2+1, template.size(5)//2:template.size(5)//2+1] = 2550. + template[:, :, 2, :, template.size(4)//2:template.size(4)//2+1, template.size(5)//2:template.size(5)//2+1] = 2550. 
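        # (editor note) the "viewaxes" block above paints bright axis-aligned lines
        # (one color channel per axis, with large alpha) through the center of every
        # primitive so orientation is visible in renders; the "colorprims" block
        # below instead fills each primitive with a random constant color and applies
        # a simple directional shading term based on the primitive rotation.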
+ + # give each primitive a unique color + if colorprims is not None and bool(colorprims): + lightdir = -torch.tensor([1., 1., 1.], device=template.device) + lightdir = lightdir / torch.sqrt(torch.sum(lightdir ** 2)) + zz, yy, xx = torch.meshgrid( + torch.linspace(-1., 1., self.primsize[2], device=template.device), + torch.linspace(-1., 1., self.primsize[1], device=template.device), + torch.linspace(-1., 1., self.primsize[0], device=template.device)) + primnormalx = torch.where( + (torch.abs(xx) >= torch.abs(yy)) & (torch.abs(xx) >= torch.abs(zz)), + torch.sign(xx) * torch.ones_like(xx), + torch.zeros_like(xx)) + primnormaly = torch.where( + (torch.abs(yy) >= torch.abs(xx)) & (torch.abs(yy) >= torch.abs(zz)), + torch.sign(yy) * torch.ones_like(xx), + torch.zeros_like(xx)) + primnormalz = torch.where( + (torch.abs(zz) >= torch.abs(xx)) & (torch.abs(zz) >= torch.abs(yy)), + torch.sign(zz) * torch.ones_like(xx), + torch.zeros_like(xx)) + primnormal = torch.stack([primnormalx, primnormaly, primnormalz], dim=-1) + primnormal = F.normalize(primnormal, dim=-1) + + torch.manual_seed(123456) + + gridz, gridy, gridx = torch.meshgrid( + torch.linspace(-1., 1., self.primsize[2], device=encoding.device), + torch.linspace(-1., 1., self.primsize[1], device=encoding.device), + torch.linspace(-1., 1., self.primsize[0], device=encoding.device)) + grid = torch.stack([gridx, gridy, gridz], dim=-1) + + if chlast is not None and chlast: + template[:] = torch.rand(1, template.size(1), 1, 1, 1, template.size(-1), device=template.device) * 255. + template[:, :, :, :, :, 3] = 1000. + else: + template[:] = torch.rand(1, template.size(1), template.size(2), 1, 1, 1, device=template.device) * 255. + template[:, :, 3, :, :, :] = 1000. + + if chlast is not None and chlast: + lightdir0 = torch.sum(primrot[:, :, :, :] * lightdir[None, None, :, None], dim=-2) + template[:, :, :, :, :, :3] *= 1.2 * torch.sum( + lightdir0[:, :, None, None, None, :] * primnormal, dim=-1)[:, :, :, :, :, None].clamp(min=0.05) + else: + lightdir0 = torch.sum(primrot[:, :, :, :] * lightdir[None, None, :, None], dim=-2) + template[:, :, :3, :, :, :] *= 1.2 * torch.sum( + lightdir0[:, :, None, None, None, :] * primnormal, dim=-1)[:, :, None, :, :, :].clamp(min=0.05) + + # view slab as a 2d grid + if viewslab is not None and bool(viewslab): + assert evaliter is not None + + yy, xx = torch.meshgrid( + torch.linspace(0., 1., int(math.sqrt(self.nprims)), device=template.device), + torch.linspace(0., 1., int(math.sqrt(self.nprims)), device=template.device)) + primpos0 = torch.stack([xx*1.5, 0.75-yy*1.5, xx*0.+0.5], dim=-1)[None, :, :, :].repeat(template.size(0), 1, 1, 1).view(-1, self.nprims, 3) + primrot0 = torch.eye(3, device=template.device)[None, None, :, :].repeat(template.size(0), self.nprims, 1, 1) + primrot0.data[:, :, 1, 1] *= -1. + primscale0 = torch.ones((template.size(0), self.nprims, 3), device=template.device) * math.sqrt(self.nprims) * 1.25 #* 0.5 + + blend = ((evaliter - 256.) / 64.).clamp(min=0., max=1.)[:, None, None] + blend = 3 * blend ** 2 - 2 * blend ** 3 + + primpos = (1. - blend) * primpos0 + blend * primpos + primrot = models.utils.rotation_interp(primrot0, primrot, blend) + primscale = torch.exp((1. - blend) * torch.log(primscale0) + blend * torch.log(primscale)) + + losses = {} + + # prior on primitive volume + if "primvolsum" in losslist: + losses["primvolsum"] = torch.sum(torch.prod(1. 
/ primscale, dim=-1), dim=-1) + + if "logprimscalevar" in losslist: + logprimscale = torch.log(primscale) + logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True) + losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2) + + result = { + "template": template, + "primpos": primpos, + "primrot": primrot, + "primscale": primscale} + if primtransf is not None: + result["primtransf"] = primtransf + if warp is not None: + result["warp"] = warp + if geo is not None: + result["verts"] = geo + return result, losses diff --git a/dva/mvp/models/decoders/nv.py b/dva/mvp/models/decoders/nv.py new file mode 100644 index 0000000000000000000000000000000000000000..0c99ce2295c3e9ea4701f424b52eb5820d9c14b2 --- /dev/null +++ b/dva/mvp/models/decoders/nv.py @@ -0,0 +1,306 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Neural Volumes decoder """ +import math +from typing import Optional, Dict, List + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models.utils +from models.utils import LinearELR, ConvTranspose2dELR, ConvTranspose3dELR + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + return x.view(self.shape) + +class ContentDecoder(nn.Module): + def __init__(self, primsize, inch, outch, chstart=256, hstart=4, + texwarp=False, elr=True, norm=None, mod=False, ub=True, upconv=None, + penultch=None): + super(ContentDecoder, self).__init__() + + assert not texwarp + assert upconv == None + + self.primsize = primsize + + nlayers = int(math.log2(self.primsize / hstart)) + + lastch = chstart + dims = (hstart, hstart, hstart) + + layers = [] + layers.append(LinearELR(inch, chstart*dims[0]*dims[1]*dims[2], act=nn.LeakyReLU(0.2))) + layers.append(Reshape(-1, chstart, dims[0], dims[1], dims[2])) + + for i in range(nlayers): + nextch = lastch if i % 2 == 0 else lastch // 2 + + if i == nlayers - 2 and penultch is not None: + nextch = penultch + + layers.append(ConvTranspose3dELR( + lastch, + (outch if i == nlayers - 1 else nextch), + 4, 2, 1, + ub=(dims[0]*2, dims[1]*2, dims[2]*2) if ub else None, + norm=None if i == nlayers - 1 else norm, + act=None if i == nlayers - 1 else nn.LeakyReLU(0.2) + )) + + lastch = nextch + dims = (dims[0] * 2, dims[1] * 2, dims[2] * 2) + + self.mod = nn.Sequential(*layers) + + def forward(self, enc, renderoptions : Dict[str, str], trainiter : Optional[int]=None): + x = self.mod(enc) + + algo = renderoptions.get("algo") + chlast = renderoptions.get("chlast") + + if chlast is not None and bool(chlast): + # reorder channels last + outch = x.size(1) + x = x.permute(0, 2, 3, 4, 1)[:, None, :, :, :, :].contiguous() + else: + outch = x.size(1) + x = x[:, None, :, :, :, :].contiguous() + + return x + +def get_dec(dectype, **kwargs): + if dectype == "conv": + return ContentDecoder(**kwargs) + else: + raise + +class Decoder(nn.Module): + def __init__(self, + volradius, + dectype="conv", + primsize=128, + chstart=256, + penultch=None, + condsize=0, + warptype="conv", + warpprimsize=32, + sharedrgba=False, + norm=None, + mod=False, + elr=True, + notplateact=False, + postrainstart=-1, + alphatrainstart=-1, + renderoptions={}, + **kwargs): + """ + Parameters + ---------- + volradius : float + radius of bounding volume of scene + dectype : string + type of content decoder, 
options are "slab2d", "slab2d3d", "slab2d3dv2" + primsize : Tuple[int, int, int] + size of primitive dimensions + postrainstart : int + training iterations to start learning position, rotation, and + scaling (i.e., primitives stay frozen until this iteration number) + condsize : int + unused + motiontype : string + motion model, options are "linear" and "deconv" + warptype : string + warp model, options are "same" to use same architecture as content + or None + sharedrgba : bool + True to use 1 branch to output rgba, False to use 1 branch for rgb + and 1 branch for alpha + """ + super(Decoder, self).__init__() + + self.volradius = volradius + self.postrainstart = postrainstart + self.alphatrainstart = alphatrainstart + + self.primsize = primsize + self.warpprimsize = warpprimsize + + self.notplateact = notplateact + + self.enc = LinearELR(256 + condsize, 256) + + # slab decoder (RGBA) + if sharedrgba: + self.rgbadec = get_dec(dectype, primsize=primsize, + inch=256+3, outch=4, norm=norm, mod=mod, elr=elr, + penultch=penultch, **kwargs) + + if renderoptions.get("half", False): + self.rgbadec = self.rgbadec.half() + + if renderoptions.get("chlastconv", False): + self.rgbadec = self.rgbadec.to(memory_format=torch.channels_last) + else: + self.rgbdec = get_dec(dectype, primsize=primsize, + inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod, + elr=elr, penultch=penultch, **kwargs) + self.alphadec = get_dec(dectype, primsize=primsize, + inch=256, outch=1, chstart=chstart, norm=norm, mod=mod, + elr=elr, penultch=penultch, **kwargs) + self.rgbadec = None + + if renderoptions.get("half", False): + self.rgbdec = self.rgbdec.half() + self.alphadec = self.alphadec.half() + + if renderoptions.get("chlastconv", False): + self.rgbdec = self.rgbdec.to(memory_format=torch.channels_last) + self.alphadec = self.alphadec.to(memory_format=torch.channels_last) + + # warp field decoder + if warptype is not None: + self.warpdec = get_dec(warptype, primsize=warpprimsize, + inch=256, outch=3, chstart=chstart, norm=norm, mod=mod, elr=elr, **kwargs) + else: + self.warpdec = None + + def forward(self, + encoding, + viewpos, + condinput : Optional[torch.Tensor]=None, + renderoptions : Optional[Dict[str, str]]=None, + trainiter : int=-1, + evaliter : Optional[torch.Tensor]=None, + losslist : Optional[List[str]]=None, + modelmatrix : Optional[torch.Tensor]=None): + """ + Parameters + ---------- + encoding : torch.Tensor [B, 256] + Encoding of current frame + viewpos : torch.Tensor [B, 3] + Viewing position of target camera view + condinput : torch.Tensor [B, ?] + Additional conditioning input (e.g., headpose) + renderoptions : dict + Options for rendering (e.g., rendering debug images) + trainiter : int, + Current training iteration + losslist : list, + List of losses to compute and return + + Returns + ------- + result : dict, + Contains predicted vertex positions, primitive contents and + locations, scaling, and orientation, and any losses. 
+ """ + assert renderoptions is not None + assert losslist is not None + + if condinput is not None: + encoding = torch.cat([encoding, condinput], dim=1) + + encoding = self.enc(encoding) + + viewdirs = F.normalize(viewpos, dim=1) + + primpos = torch.zeros(encoding.size(0), 1, 3, device=encoding.device) + primrot = torch.eye(3, device=encoding.device)[None, None, :, :].repeat(encoding.size(0), 1, 1, 1) + primscale = torch.ones(encoding.size(0), 1, 3, device=encoding.device) + + # options + algo = renderoptions.get("algo") + chlast = renderoptions.get("chlast") + half = renderoptions.get("half") + + if self.rgbadec is not None: + # shared rgb and alpha branch + scale = torch.tensor([25., 25., 25., 1.], device=encoding.device) + bias = torch.tensor([100., 100., 100., 0.], device=encoding.device) + if chlast is not None and bool(chlast): + scale = scale[None, None, None, None, None, :] + bias = bias[None, None, None, None, None, :] + else: + scale = scale[None, None, :, None, None, None] + bias = bias[None, None, :, None, None, None] + + templatein = torch.cat([encoding, viewdirs], dim=1) + if half is not None and bool(half): + templatein = templatein.half() + template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions) + template = bias + scale * template + if not self.notplateact: + template = F.relu(template) + if half is not None and bool(half): + template = template.float() + else: + templatein = torch.cat([encoding, viewdirs], dim=1) + if half is not None and bool(half): + templatein = templatein.half() + primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions) + primrgb = primrgb * 25. + 100. + if not self.notplateact: + primrgb = F.relu(primrgb) + + templatein = encoding + if half is not None and bool(half): + templatein = templatein.half() + primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions) + if not self.notplateact: + primalpha = F.relu(primalpha) + + if trainiter <= self.alphatrainstart: + primalpha = primalpha * 0. + 1. + + if algo is not None and int(algo) == 4: + template = torch.cat([primrgb, primalpha], dim=-1) + elif chlast is not None and bool(chlast): + template = torch.cat([primrgb, primalpha], dim=-1) + else: + template = torch.cat([primrgb, primalpha], dim=2) + if half is not None and bool(half): + template = template.float() + + if self.warpdec is not None: + warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01 + warp = warp + torch.stack(torch.meshgrid( + torch.linspace(-1., 1., self.warpprimsize, device=encoding.device), + torch.linspace(-1., 1., self.warpprimsize, device=encoding.device), + torch.linspace(-1., 1., self.warpprimsize, device=encoding.device))[::-1], + dim=-1 if chlast is not None and bool(chlast) else 0)[None, None, :, :, :, :] + warp = warp.contiguous() + else: + warp = None + + losses = {} + + # prior on primitive volume + if "primvolsum" in losslist: + losses["primvolsum"] = torch.sum(torch.prod(1. 
/ primscale, dim=-1), dim=-1) + + if "logprimscalevar" in losslist: + logprimscale = torch.log(primscale) + logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True) + losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2) + + result = { + "template": template, + "primpos": primpos, + "primrot": primrot, + "primscale": primscale} + if warp is not None: + result["warp"] = warp + return result, losses diff --git a/dva/mvp/models/encoders/geotex.py b/dva/mvp/models/encoders/geotex.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b1fc5a13ba67126930a8cebb3f0c4925be3d15 --- /dev/null +++ b/dva/mvp/models/encoders/geotex.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional, List + +import numpy as np + +import torch +import torch.nn as nn + +from models.utils import LinearELR, Conv2dELR + +class Encoder(torch.nn.Module): + def __init__(self, latentdim=256, hiq=True, texin=True, + conv=Conv2dELR, lin=LinearELR, + demod=True, texsize=1024, vertsize=21918): + super(Encoder, self).__init__() + + self.latentdim = latentdim + + self.vertbranch = lin(vertsize, 256, norm="demod", act=nn.LeakyReLU(0.2)) + if texin: + cm = 2 if hiq else 1 + + layers = [] + chout = 128*cm + chin = 128*cm + nlayers = int(np.log2(texsize)) - 2 + for i in range(nlayers): + if i == nlayers - 1: + chin = 3 + layers.append( + conv(chin, chout, 4, 2, 1, norm="demod" if demod else None, act=nn.LeakyReLU(0.2))) + if chin == chout: + chin = chout // 2 + else: + chout = chin + + self.texbranch1 = nn.Sequential(*(layers[::-1])) + + self.texbranch2 = lin(cm*128*4*4, 256, norm="demod", act=nn.LeakyReLU(0.2)) + self.mu = lin(512, self.latentdim) + self.logstd = lin(512, self.latentdim) + else: + self.mu = lin(256, self.latentdim) + self.logstd = lin(256, self.latentdim) + + def forward(self, verts, texture : Optional[torch.Tensor]=None, losslist : Optional[List[str]]=None): + assert losslist is not None + + x = self.vertbranch(verts.view(verts.size(0), -1)) + if texture is not None: + texture = self.texbranch1(texture).reshape(verts.size(0), -1) + texture = self.texbranch2(texture) + x = torch.cat([x, texture], dim=1) + + mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01 + if self.training: + z = mu + torch.exp(logstd) * torch.randn_like(logstd) + else: + z = mu + + losses = {} + if "kldiv" in losslist: + losses["kldiv"] = torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1) + + return {"encoding": z}, losses diff --git a/dva/mvp/models/encoders/image.py b/dva/mvp/models/encoders/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6da4bb7ff41db43b0bd54a05c2016d33f7aaceeb --- /dev/null +++ b/dva/mvp/models/encoders/image.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
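# --- Editor's sketch (not part of the original patch) -----------------------
# Both encoders in this directory are variational: they predict a mean and a
# log-standard-deviation, sample z = mu + exp(logstd) * eps during training,
# and regularize with the KL divergence to a standard normal. A standalone
# version of that computation, matching the "kldiv" term above:
import torch

def reparameterize(mu: torch.Tensor, logstd: torch.Tensor, training: bool = True) -> torch.Tensor:
    if not training:
        return mu                                   # deterministic mean at eval time
    return mu + torch.exp(logstd) * torch.randn_like(logstd)

def kl_to_standard_normal(mu: torch.Tensor, logstd: torch.Tensor) -> torch.Tensor:
    # KL( N(mu, sigma^2) || N(0, 1) ) with sigma = exp(logstd), averaged over
    # the latent dimension
    return torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1)

mu, logstd = torch.zeros(4, 256), torch.full((4, 256), -2.0)
z = reparameterize(mu, logstd)                      # [4, 256] latent sample
kl = kl_to_standard_normal(mu, logstd)              # [4] per-sample KL penalty
# --- end of sketch -----------------------------------------------------------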
+from typing import Optional, List + +import torch +import torch.nn as nn + +from models.utils import LinearELR, Conv2dELR, Downsample2d + +class Encoder(torch.nn.Module): + def __init__(self, ninputs, size, nlayers=7, conv=Conv2dELR, lin=LinearELR): + super(Encoder, self).__init__() + + self.ninputs = ninputs + height, width = size + self.nlayers = nlayers + + ypad = ((height + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - height + xpad = ((width + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - width + self.pad = nn.ZeroPad2d((xpad // 2, xpad - xpad // 2, ypad // 2, ypad - ypad // 2)) + + self.downwidth = ((width + 2 ** nlayers - 1) // 2 ** nlayers) + self.downheight = ((height + 2 ** nlayers - 1) // 2 ** nlayers) + + # compile layers + layers = [] + inch, outch = 3, 64 + for i in range(nlayers): + layers.append(conv(inch, outch, 4, 2, 1, norm="demod", act=nn.LeakyReLU(0.2))) + + if inch == outch: + outch = inch * 2 + else: + inch = outch + if outch > 256: + outch = 256 + + self.down1 = nn.ModuleList([nn.Sequential(*layers) + for i in range(self.ninputs)]) + self.down2 = lin(256 * self.ninputs * self.downwidth * self.downheight, 512, norm="demod", act=nn.LeakyReLU(0.2)) + self.mu = lin(512, 256) + self.logstd = lin(512, 256) + + def forward(self, x, losslist : Optional[List[str]]=None): + assert losslist is not None + + x = self.pad(x) + x = [self.down1[i](x[:, i*3:(i+1)*3, :, :]).view(x.size(0), 256 * self.downwidth * self.downheight) + for i in range(self.ninputs)] + x = torch.cat(x, dim=1) + x = self.down2(x) + + mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01 + if self.training: + z = mu + torch.exp(logstd) * torch.randn(*logstd.size(), device=logstd.device) + else: + z = mu + + losses = {} + if "kldiv" in losslist: + losses["kldiv"] = torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1) + + return {"encoding": z}, losses diff --git a/dva/mvp/models/raymarchers/mvpraymarcher.py b/dva/mvp/models/raymarchers/mvpraymarcher.py new file mode 100644 index 0000000000000000000000000000000000000000..bffb39aa55bfb728a2355c2ea2ab101798d78919 --- /dev/null +++ b/dva/mvp/models/raymarchers/mvpraymarcher.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
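# --- Editor's sketch (not part of the original patch) -----------------------
# The image encoder above zero-pads its input so height and width become
# multiples of 2**nlayers before the chain of stride-2 convolutions, splitting
# the padding as evenly as possible between the two sides. A standalone helper:
import torch.nn as nn

def make_pad_to_multiple(height: int, width: int, nlayers: int) -> nn.ZeroPad2d:
    mult = 2 ** nlayers
    ypad = ((height + mult - 1) // mult) * mult - height
    xpad = ((width + mult - 1) // mult) * mult - width
    # ZeroPad2d takes (left, right, top, bottom)
    return nn.ZeroPad2d((xpad // 2, xpad - xpad // 2, ypad // 2, ypad - ypad // 2))

pad = make_pad_to_multiple(667, 1024, nlayers=7)    # pads rows 667 -> 768, columns unchanged
# --- end of sketch -----------------------------------------------------------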
+""" Raymarcher for a mixture of volumetric primitives """ +import os +import itertools +import time +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from extensions.mvpraymarch.mvpraymarch import mvpraymarch + +class Raymarcher(nn.Module): + def __init__(self, volradius): + super(Raymarcher, self).__init__() + + self.volradius = volradius + + def forward(self, raypos, raydir, tminmax, decout, + encoding=None, renderoptions={}, trainiter=-1, evaliter=-1, + rayterm=None, + **kwargs): + + # rescale world + dt = renderoptions["dt"] / self.volradius + + rayrgba = mvpraymarch(raypos, raydir, dt, tminmax, + (decout["primpos"], decout["primrot"], decout["primscale"]), + template=decout["template"], + warp=decout["warp"] if "warp" in decout else None, + rayterm=rayterm, + **{k:v for k, v in renderoptions.items() if k in mvpraymarch.__code__.co_varnames}) + + return rayrgba.permute(0, 3, 1, 2), {} diff --git a/dva/mvp/models/raymarchers/stepraymarcher.py b/dva/mvp/models/raymarchers/stepraymarcher.py new file mode 100644 index 0000000000000000000000000000000000000000..0a17ff0ae4c5cf069b5a67dd69c61184ac602478 --- /dev/null +++ b/dva/mvp/models/raymarchers/stepraymarcher.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" Raymarching in pure pytorch """ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Raymarcher(nn.Module): + def __init__(self, volradius): + super(Raymarcher, self).__init__() + + self.volradius = volradius + + def forward(self, raypos, raydir, tminmax, decout, + encoding=None, renderoptions={}, **kwargs): + + dt = renderoptions["dt"] / self.volradius + + tminmax = torch.floor(tminmax / dt) * dt + + t = tminmax[..., 0] + 0. + raypos = raypos + raydir * t[..., None] + + rayrgb = torch.zeros_like(raypos.permute(0, 3, 1, 2)) # NCHW + if "multaccum" in renderoptions and renderoptions["multaccum"]: + lograyalpha = torch.zeros_like(rayrgb[:, 0:1, :, :]) # NCHW + else: + rayalpha = torch.zeros_like(rayrgb[:, 0:1, :, :]) # NCHW + + # raymarch + done = torch.zeros_like(t).bool() + while not done.all(): + valid = torch.prod((raypos > -1.) * (raypos < 1.), dim=-1).float() + samplepos = F.grid_sample(decout["warp"][:, 0], raypos[:, None, :, :, :], align_corners=True).permute(0, 2, 3, 4, 1) + val = F.grid_sample(decout["template"][:, 0], samplepos, align_corners=True)[:, :, 0, :, :] + val = val * valid[:, None, :, :] + sample_rgb, sample_alpha = val[:, :3, :, :], val[:, 3:, :, :] + + done = done | ((t + dt) >= tminmax[..., 1]) + + if "multaccum" in renderoptions and renderoptions["multaccum"]: + contrib = torch.exp(-lograyalpha) * (1. - torch.exp(-sample_alpha * dt)) + + rayrgb = rayrgb + sample_rgb * contrib + lograyalpha = lograyalpha + sample_alpha * dt + else: + contrib = ((rayalpha + sample_alpha * dt).clamp(max=1.) - rayalpha) + + rayrgb = rayrgb + sample_rgb * contrib + rayalpha = rayalpha + contrib + + raypos = raypos + raydir * dt + t = t + dt + + if "multaccum" in renderoptions and renderoptions["multaccum"]: + rayalpha = 1. 
- torch.exp(-lograyalpha) + + rayrgba = torch.cat([rayrgb, rayalpha], dim=1) + return rayrgba, {} diff --git a/dva/mvp/models/utils.py b/dva/mvp/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2d21967a2b9b1a13771de36b2d8691017d81db --- /dev/null +++ b/dva/mvp/models/utils.py @@ -0,0 +1,942 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""PyTorch utilities""" +from collections import OrderedDict +from itertools import islice +import math +import operator +from typing import Optional, Union + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def xaviermultiplier(m, gain): + if isinstance(m, nn.Conv1d): + ksize = m.kernel_size[0] + n1 = m.in_channels + n2 = m.out_channels + + std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) + elif isinstance(m, nn.ConvTranspose1d): + ksize = m.kernel_size[0] // m.stride[0] + n1 = m.in_channels + n2 = m.out_channels + + std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) + elif isinstance(m, nn.Conv2d): + ksize = m.kernel_size[0] * m.kernel_size[1] + n1 = m.in_channels + n2 = m.out_channels + + std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) + elif isinstance(m, nn.ConvTranspose2d): + ksize = m.kernel_size[0] * m.kernel_size[1] // m.stride[0] // m.stride[1] + n1 = m.in_channels + n2 = m.out_channels + + std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) + elif isinstance(m, nn.Conv3d): + ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] + n1 = m.in_channels + n2 = m.out_channels + + std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) + elif isinstance(m, nn.ConvTranspose3d): + ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] // m.stride[0] // m.stride[1] // m.stride[2] + n1 = m.in_channels + n2 = m.out_channels + + std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) + elif isinstance(m, nn.Linear): + n1 = m.in_features + n2 = m.out_features + + std = gain * math.sqrt(2.0 / (n1 + n2)) + else: + return None + + return std + +### normal initialization routines +def xavier_uniform_(m, gain): + std = xaviermultiplier(m, gain) + m.weight.data.uniform_(-std * math.sqrt(3.0), std * math.sqrt(3.0)) + +def initmod(m, gain=1.0, weightinitfunc=xavier_uniform_): + validclasses = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d] + if any([isinstance(m, x) for x in validclasses]): + weightinitfunc(m, gain) + if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor): + m.bias.data.zero_() + + # blockwise initialization for transposed convs + if isinstance(m, nn.ConvTranspose2d): + # hardcoded for stride=2 for now + m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] + m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2] + m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] + + if isinstance(m, nn.ConvTranspose3d): + # hardcoded for stride=2 for now + m.weight.data[:, :, 0::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] + m.weight.data[:, :, 0::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] + m.weight.data[:, :, 0::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] + m.weight.data[:, :, 1::2, 0::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] + m.weight.data[:, :, 1::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] + m.weight.data[:, :, 1::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 
0::2] + m.weight.data[:, :, 1::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] + + if isinstance(m, Conv2dWNUB) or isinstance(m, Conv2dWN) or isinstance(m, ConvTranspose2dWN) or \ + isinstance(m, ConvTranspose2dWNUB) or isinstance(m, LinearWN): + norm = np.sqrt(torch.sum(m.weight.data[:] ** 2)) + m.g.data[:] = norm + +def initseq(s): + for a, b in zip(s[:-1], s[1:]): + if isinstance(b, nn.ReLU): + initmod(a, nn.init.calculate_gain('relu')) + elif isinstance(b, nn.LeakyReLU): + initmod(a, nn.init.calculate_gain('leaky_relu', b.negative_slope)) + elif isinstance(b, nn.Sigmoid): + initmod(a) + elif isinstance(b, nn.Softplus): + initmod(a) + else: + initmod(a) + + initmod(s[-1]) + +### custom modules +class LinearWN(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(LinearWN, self).__init__(in_features, out_features, bias) + self.g = nn.Parameter(torch.ones(out_features)) + self.fused = False + + def fuse(self): + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + self.weight.data = self.weight.data * self.g.data[:, None] / wnorm + self.fused = True + + def forward(self, input): + if self.fused: + return F.linear(input, self.weight, self.bias) + else: + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + return F.linear(input, self.weight * self.g[:, None] / wnorm, self.bias) + +class LinearELR(nn.Module): + """Linear layer with equalized learning rate from stylegan2""" + def __init__(self, inch, outch, lrmult=1., norm : Optional[str]=None, act=None): + super(LinearELR, self).__init__() + + # compute gain from activation fn + try: + if isinstance(act, nn.LeakyReLU): + actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) + elif isinstance(act, nn.ReLU): + actgain = nn.init.calculate_gain("relu") + else: + actgain = nn.init.calculate_gain(act) + except: + actgain = 1. + + initgain = 1. 
/ math.sqrt(inch) + + self.weight = nn.Parameter(torch.randn(outch, inch) / lrmult) + self.weightgain = actgain + + if norm == None: + self.weightgain = self.weightgain * initgain * lrmult + + self.bias = nn.Parameter(torch.full([outch], 0.)) + + self.norm : Optional[str] = norm + self.act = act + + self.fused = False + + def extra_repr(self): + return 'inch={}, outch={}, norm={}, act={}'.format( + self.weight.size(1), self.weight.size(0), self.norm, self.act + ) + + def getweight(self): + if self.fused: + return self.weight + else: + weight = self.weight + if self.norm is not None: + if self.norm == "demod": + weight = F.normalize(weight, dim=1) + return weight + + def fuse(self): + if not self.fused: + with torch.no_grad(): + self.weight.data = self.getweight() * self.weightgain + self.fused = True + + def forward(self, x): + if self.fused: + weight = self.getweight() + + out = torch.addmm(self.bias[None], x, weight.t()) + if self.act is not None: + out = self.act(out) + return out + else: + weight = self.getweight() + + if self.act is None: + out = torch.addmm(self.bias[None], x, weight.t(), alpha=self.weightgain) + return out + else: + out = F.linear(x, weight * self.weightgain, bias=self.bias) + out = self.act(out) + return out + +class Downsample2d(nn.Module): + def __init__(self, nchannels, stride=1, padding=0): + super(Downsample2d, self).__init__() + + self.nchannels = nchannels + self.stride = stride + self.padding = padding + + blurkernel = torch.tensor([1., 6., 15., 20., 15., 6., 1.]) + blurkernel = (blurkernel[:, None] * blurkernel[None, :]) + blurkernel = blurkernel / torch.sum(blurkernel) + blurkernel = blurkernel[None, None, :, :].repeat(nchannels, 1, 1, 1) + self.register_buffer('kernel', blurkernel) + + def forward(self, x): + if self.padding == "reflect": + x = F.pad(x, (3, 3, 3, 3), mode='reflect') + return F.conv2d(x, weight=self.kernel, stride=self.stride, padding=0, groups=self.nchannels) + else: + return F.conv2d(x, weight=self.kernel, stride=self.stride, padding=self.padding, groups=self.nchannels) + +class Dilate2d(nn.Module): + def __init__(self, nchannels, kernelsize, stride=1, padding=0): + super(Dilate2d, self).__init__() + + self.nchannels = nchannels + self.kernelsize = kernelsize + self.stride = stride + self.padding = padding + + blurkernel = torch.ones((self.kernelsize,)) + blurkernel = (blurkernel[:, None] * blurkernel[None, :]) + blurkernel = blurkernel / torch.sum(blurkernel) + blurkernel = blurkernel[None, None, :, :].repeat(nchannels, 1, 1, 1) + self.register_buffer('kernel', blurkernel) + + def forward(self, x): + return F.conv2d(x, weight=self.kernel, stride=self.stride, padding=self.padding, groups=self.nchannels).clamp(max=1.) 
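The weight-normalized layers in this file (LinearWN here, and the Conv2dWN / ConvTranspose2dWN variants below) reparameterize the weight with a learnable gain g divided by the global weight norm, and fuse() bakes that scaling into the stored weight so inference skips the division. A minimal sanity check, assuming dva/mvp/models/utils.py is importable as models.utils (the path volumetric.py uses):

import torch
from models.utils import LinearWN

lin = LinearWN(8, 4)
x = torch.randn(2, 8)
y_unfused = lin(x)   # divides by the weight norm on the fly
lin.fuse()           # bakes g / ||W|| into lin.weight
y_fused = lin(x)
print(torch.allclose(y_unfused, y_fused, atol=1e-6))  # expected: True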
+ +class Conv2dWN(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + super(Conv2dWN, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, True) + self.g = nn.Parameter(torch.ones(out_channels)) + + def forward(self, x): + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + return F.conv2d(x, self.weight * self.g[:, None, None, None] / wnorm, + bias=self.bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + +class Conv2dUB(nn.Conv2d): + def __init__(self, in_channels, out_channels, height, width, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=False): + super(Conv2dUB, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, False) + self.bias = nn.Parameter(torch.zeros(out_channels, height, width)) + + def forward(self, x): + return F.conv2d(x, self.weight, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + self.bias[None, ...] + +class Conv2dWNUB(nn.Conv2d): + def __init__(self, in_channels, out_channels, height, width, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=False): + super(Conv2dWNUB, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, False) + self.g = nn.Parameter(torch.ones(out_channels)) + self.bias = nn.Parameter(torch.zeros(out_channels, height, width)) + + def forward(self, x): + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + return F.conv2d(x, self.weight * self.g[:, None, None, None] / wnorm, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + self.bias[None, ...] + +def blockinit(k, stride): + dim = k.ndim - 2 + return k \ + .view(k.size(0), k.size(1), *(x for i in range(dim) for x in (k.size(i+2), 1))) \ + .repeat(1, 1, *(x for i in range(dim) for x in (1, stride))) \ + .view(k.size(0), k.size(1), *(k.size(i+2)*stride for i in range(dim))) + +class ConvTranspose1dELR(nn.Module): + def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): + super(ConvTranspose1dELR, self).__init__() + + self.inch = inch + self.outch = outch + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.wsize = wsize + self.norm = norm + self.ub = ub + self.act = act + + # compute gain from activation fn + try: + if isinstance(act, nn.LeakyReLU): + actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) + elif isinstance(act, nn.ReLU): + actgain = nn.init.calculate_gain("relu") + else: + actgain = nn.init.calculate_gain(act) + except: + actgain = 1. + + fan_in = inch * (kernel_size / (stride)) + + initgain = stride ** 0.5 if norm == "demod" else 1. 
/ math.sqrt(fan_in) + + self.weightgain = actgain * initgain + + self.weight = nn.Parameter(blockinit( + torch.randn(inch, outch, kernel_size//self.stride), self.stride)) + + if ub is not None: + self.bias = nn.Parameter(torch.zeros(outch, ub[0])) + else: + self.bias = nn.Parameter(torch.zeros(outch)) + + if wsize > 0: + self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) + else: + self.affine = None + + self.fused = False + + def extra_repr(self): + return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( + self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act + ) + + def getweight(self, weight): + if self.fused: + return weight + else: + if self.norm is not None: + if self.norm == "demod": + if weight.ndim == 5: + normdims = [1, 3] + else: + normdims = [0, 2] + + if torch.jit.is_scripting(): + # scripting doesn't support F.normalize(..., dim=list[int]) + weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) + else: + weight = F.normalize(weight, dim=normdims) + + weight = weight * self.weightgain + + return weight + + def fuse(self): + if self.affine is None: + with torch.no_grad(): + self.weight.data = self.getweight(self.weight) + self.fused = True + + def forward(self, x, w : Optional[torch.Tensor]=None): + b = x.size(0) + + if self.affine is not None and w is not None: + # modulate + affine = self.affine(w)[:, :, None, None] # [B, inch, 1, 1] + weight = self.weight * (affine * 0.1 + 1.) + else: + weight = self.weight + + weight = self.getweight(weight) + + if self.affine is not None and w is not None: + x = x.view(1, b * self.inch, x.size(2)) + weight = weight.view(b * self.inch, self.outch, self.kernel_size) + groups = b + else: + groups = 1 + + out = F.conv_transpose1d(x, weight, None, + stride=self.stride, padding=self.padding, dilation=1, groups=groups) + + if self.affine is not None and w is not None: + out = out.view(b, self.outch, out.size(2)) + + if self.bias.ndim == 1: + bias = self.bias[None, :, None] + else: + bias = self.bias[None, :, :] + out = out + bias + + if self.act is not None: + out = self.act(out) + + return out + +class ConvTranspose2dELR(nn.Module): + def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): + super(ConvTranspose2dELR, self).__init__() + + self.inch = inch + self.outch = outch + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.wsize = wsize + self.norm = norm + self.ub = ub + self.act = act + + # compute gain from activation fn + try: + if isinstance(act, nn.LeakyReLU): + actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) + elif isinstance(act, nn.ReLU): + actgain = nn.init.calculate_gain("relu") + else: + actgain = nn.init.calculate_gain(act) + except: + actgain = 1. + + fan_in = inch * (kernel_size ** 2 / (stride ** 2)) + + initgain = stride if norm == "demod" else 1. 
/ math.sqrt(fan_in) + + self.weightgain = actgain * initgain + + self.weight = nn.Parameter(blockinit( + torch.randn(inch, outch, kernel_size//self.stride, kernel_size//self.stride), self.stride)) + + if ub is not None: + self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1])) + else: + self.bias = nn.Parameter(torch.zeros(outch)) + + if wsize > 0: + self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) + else: + self.affine = None + + self.fused = False + + def extra_repr(self): + return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( + self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act + ) + + def getweight(self, weight): + if self.fused: + return weight + else: + if self.norm is not None: + if self.norm == "demod": + if weight.ndim == 5: + normdims = [1, 3, 4] + else: + normdims = [0, 2, 3] + + if torch.jit.is_scripting(): + # scripting doesn't support F.normalize(..., dim=list[int]) + weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) + else: + weight = F.normalize(weight, dim=normdims) + + weight = weight * self.weightgain + + return weight + + def fuse(self): + if self.affine is None: + with torch.no_grad(): + self.weight.data = self.getweight(self.weight) + self.fused = True + + def forward(self, x, w : Optional[torch.Tensor]=None): + b = x.size(0) + + if self.affine is not None and w is not None: + # modulate + affine = self.affine(w)[:, :, None, None, None] # [B, inch, 1, 1, 1] + weight = self.weight * (affine * 0.1 + 1.) + else: + weight = self.weight + + weight = self.getweight(weight) + + if self.affine is not None and w is not None: + x = x.view(1, b * self.inch, x.size(2), x.size(3)) + weight = weight.view(b * self.inch, self.outch, self.kernel_size, self.kernel_size) + groups = b + else: + groups = 1 + + out = F.conv_transpose2d(x, weight, None, + stride=self.stride, padding=self.padding, dilation=1, groups=groups) + + if self.affine is not None and w is not None: + out = out.view(b, self.outch, out.size(2), out.size(3)) + + if self.bias.ndim == 1: + bias = self.bias[None, :, None, None] + else: + bias = self.bias[None, :, :, :] + out = out + bias + + if self.act is not None: + out = self.act(out) + + return out + +class ConvTranspose3dELR(nn.Module): + def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): + super(ConvTranspose3dELR, self).__init__() + + self.inch = inch + self.outch = outch + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.wsize = wsize + self.norm = norm + self.ub = ub + self.act = act + + # compute gain from activation fn + try: + if isinstance(act, nn.LeakyReLU): + actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) + elif isinstance(act, nn.ReLU): + actgain = nn.init.calculate_gain("relu") + else: + actgain = nn.init.calculate_gain(act) + except: + actgain = 1. + + fan_in = inch * (kernel_size ** 3 / (stride ** 3)) + + initgain = stride ** 1.5 if norm == "demod" else 1. 
/ math.sqrt(fan_in) + + self.weightgain = actgain * initgain + + self.weight = nn.Parameter(blockinit( + torch.randn(inch, outch, kernel_size//self.stride, kernel_size//self.stride, kernel_size//self.stride), self.stride)) + + if ub is not None: + self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1], ub[2])) + else: + self.bias = nn.Parameter(torch.zeros(outch)) + + if wsize > 0: + self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) + else: + self.affine = None + + self.fused = False + + def extra_repr(self): + return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( + self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act + ) + + def getweight(self, weight): + if self.fused: + return weight + else: + if self.norm is not None: + if self.norm == "demod": + if weight.ndim == 5: + normdims = [1, 3, 4, 5] + else: + normdims = [0, 2, 3, 4] + + if torch.jit.is_scripting(): + # scripting doesn't support F.normalize(..., dim=list[int]) + weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) + else: + weight = F.normalize(weight, dim=normdims) + + weight = weight * self.weightgain + + return weight + + def fuse(self): + if self.affine is None: + with torch.no_grad(): + self.weight.data = self.getweight(self.weight) + self.fused = True + + def forward(self, x, w : Optional[torch.Tensor]=None): + b = x.size(0) + + if self.affine is not None and w is not None: + # modulate + affine = self.affine(w)[:, :, None, None, None, None] # [B, inch, 1, 1, 1, 1] + weight = self.weight * (affine * 0.1 + 1.) + else: + weight = self.weight + + weight = self.getweight(weight) + + if self.affine is not None and w is not None: + x = x.view(1, b * self.inch, x.size(2), x.size(3), x.size(4)) + weight = weight.view(b * self.inch, self.outch, self.kernel_size, self.kernel_size, self.kernel_size) + groups = b + else: + groups = 1 + + out = F.conv_transpose3d(x, weight, None, + stride=self.stride, padding=self.padding, dilation=1, groups=groups) + + if self.affine is not None and w is not None: + out = out.view(b, self.outch, out.size(2), out.size(3), out.size(4)) + + if self.bias.ndim == 1: + bias = self.bias[None, :, None, None, None] + else: + bias = self.bias[None, :, :, :, :] + out = out + bias + + if self.act is not None: + out = self.act(out) + + return out + +class Conv2dELR(nn.Module): + def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): + super(Conv2dELR, self).__init__() + + self.inch = inch + self.outch = outch + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.wsize = wsize + self.norm = norm + self.ub = ub + self.act = act + + # compute gain from activation fn + try: + if isinstance(act, nn.LeakyReLU): + actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) + elif isinstance(act, nn.ReLU): + actgain = nn.init.calculate_gain("relu") + else: + actgain = nn.init.calculate_gain(act) + except: + actgain = 1. + + fan_in = inch * (kernel_size ** 2) + + initgain = 1. if norm == "demod" else 1. 
/ math.sqrt(fan_in) + + self.weightgain = actgain * initgain + + self.weight = nn.Parameter( + torch.randn(outch, inch, kernel_size, kernel_size)) + + if ub is not None: + self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1])) + else: + self.bias = nn.Parameter(torch.zeros(outch)) + + if wsize > 0: + self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) + else: + self.affine = None + + self.fused = False + + def extra_repr(self): + return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( + self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act + ) + + def getweight(self, weight): + if self.fused: + return weight + else: + if self.norm is not None: + if self.norm == "demod": + if weight.ndim == 5: + normdims = [2, 3, 4] + else: + normdims = [1, 2, 3] + + if torch.jit.is_scripting(): + # scripting doesn't support F.normalize(..., dim=list[int]) + weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) + else: + weight = F.normalize(weight, dim=normdims) + + weight = weight * self.weightgain + + return weight + + def fuse(self): + if self.affine is None: + with torch.no_grad(): + self.weight.data = self.getweight(self.weight) + self.fused = True + + def forward(self, x, w : Optional[torch.Tensor]=None): + b = x.size(0) + + if self.affine is not None and w is not None: + # modulate + affine = self.affine(w)[:, None, :, None, None] # [B, 1, inch, 1, 1] + weight = self.weight * (affine * 0.1 + 1.) + else: + weight = self.weight + + weight = self.getweight(weight) + + if self.affine is not None and w is not None: + x = x.view(1, b * self.inch, x.size(2), x.size(3)) + weight = weight.view(b * self.outch, self.inch, self.kernel_size, self.kernel_size) + groups = b + else: + groups = 1 + + out = F.conv2d(x, weight, None, + stride=self.stride, padding=self.padding, dilation=1, groups=groups) + + if self.affine is not None and w is not None: + out = out.view(b, self.outch, out.size(2), out.size(3)) + + if self.bias.ndim == 1: + bias = self.bias[None, :, None, None] + else: + bias = self.bias[None, :, :, :] + out = out + bias + + if self.act is not None: + out = self.act(out) + + return out + +class ConvTranspose2dWN(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + super(ConvTranspose2dWN, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, True) + self.g = nn.Parameter(torch.ones(out_channels)) + self.fused = False + + def fuse(self): + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + self.weight.data = self.weight.data * self.g.data[None, :, None, None] / wnorm + self.fused = True + + def forward(self, x): + bias = self.bias + assert bias is not None + if self.fused: + return F.conv_transpose2d(x, self.weight, + bias=self.bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + else: + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + return F.conv_transpose2d(x, self.weight * self.g[None, :, None, None] / wnorm, + bias=self.bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + +class ConvTranspose2dUB(nn.ConvTranspose2d): + def __init__(self, width, height, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=False): + super(ConvTranspose2dUB, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, 
dilation, groups, False) + self.bias_ = nn.Parameter(torch.zeros(out_channels, height, width)) + + def forward(self, x): + return F.conv_transpose2d(x, self.weight, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + self.bias_[None, ...] + +class ConvTranspose2dWNUB(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, height, width, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=False): + super(ConvTranspose2dWNUB, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, False) + self.g = nn.Parameter(torch.ones(out_channels)) + self.bias = nn.Parameter(torch.zeros(out_channels, height, width)) + #self.biasf = nn.Parameter(torch.zeros(out_channels, height, width)) + self.fused = False + + def fuse(self): + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + self.weight.data = self.weight.data * self.g.data[None, :, None, None] / wnorm + self.fused = True + + def forward(self, x): + bias = self.bias + assert bias is not None + if self.fused: + return F.conv_transpose2d(x, self.weight, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + bias[None, ...] + else: + wnorm = torch.sqrt(torch.sum(self.weight ** 2)) + return F.conv_transpose2d(x, self.weight * self.g[None, :, None, None] / wnorm, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + bias[None, ...] + +class Conv3dUB(nn.Conv3d): + def __init__(self, width, height, depth, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + super(Conv3dUB, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, False) + self.bias = nn.Parameter(torch.zeros(out_channels, depth, height, width)) + + def forward(self, x): + return F.conv3d(x, self.weight, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + self.bias[None, ...] + +class ConvTranspose3dUB(nn.ConvTranspose3d): + def __init__(self, width, height, depth, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True): + super(ConvTranspose3dUB, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, False) + self.bias = nn.Parameter(torch.zeros(out_channels, depth, height, width)) + + def forward(self, x): + return F.conv_transpose3d(x, self.weight, + bias=None, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups) + self.bias[None, ...] + +class Rodrigues(nn.Module): + def __init__(self): + super(Rodrigues, self).__init__() + + def forward(self, rvec): + theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) + rvec = rvec / theta[:, None] + costh = torch.cos(theta) + sinth = torch.sin(theta) + return torch.stack(( + rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh, + rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth, + rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth, + + rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth, + rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh, + rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth, + + rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth, + rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth, + rvec[:, 2] ** 2 + (1. 
- rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3) + +class Quaternion(nn.Module): + def __init__(self): + super(Quaternion, self).__init__() + + def forward(self, rvec): + theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) + rvec = rvec / theta[:, None] + return torch.stack(( + 1. - 2. * rvec[:, 1] ** 2 - 2. * rvec[:, 2] ** 2, + 2. * (rvec[:, 0] * rvec[:, 1] - rvec[:, 2] * rvec[:, 3]), + 2. * (rvec[:, 0] * rvec[:, 2] + rvec[:, 1] * rvec[:, 3]), + + 2. * (rvec[:, 0] * rvec[:, 1] + rvec[:, 2] * rvec[:, 3]), + 1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 2] ** 2, + 2. * (rvec[:, 1] * rvec[:, 2] - rvec[:, 0] * rvec[:, 3]), + + 2. * (rvec[:, 0] * rvec[:, 2] - rvec[:, 1] * rvec[:, 3]), + 2. * (rvec[:, 0] * rvec[:, 3] + rvec[:, 1] * rvec[:, 2]), + 1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 1] ** 2 + ), dim=1).view(-1, 3, 3) + +class BufferDict(nn.Module): + def __init__(self, d, persistent=False): + super(BufferDict, self).__init__() + + for k in d: + self.register_buffer(k, d[k], persistent=False) + + def __getitem__(self, key): + return self._buffers[key] + + def __setitem__(self, key, parameter): + self.register_buffer(key, parameter, persistent=False) + +def matrix_to_axisangle(r): + th = torch.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.))[..., None] + vec = 0.5 * torch.stack([ + r[..., 2, 1] - r[..., 1, 2], + r[..., 0, 2] - r[..., 2, 0], + r[..., 1, 0] - r[..., 0, 1]], dim=-1) / torch.sin(th) + return th, vec + +@torch.jit.script +def axisangle_to_matrix(rvec : torch.Tensor): + theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=-1)) + rvec = rvec / theta[..., None] + costh = torch.cos(theta) + sinth = torch.sin(theta) + return torch.stack(( + torch.stack((rvec[..., 0] ** 2 + (1. - rvec[..., 0] ** 2) * costh, + rvec[..., 0] * rvec[..., 1] * (1. - costh) - rvec[..., 2] * sinth, + rvec[..., 0] * rvec[..., 2] * (1. - costh) + rvec[..., 1] * sinth), dim=-1), + + torch.stack((rvec[..., 0] * rvec[..., 1] * (1. - costh) + rvec[..., 2] * sinth, + rvec[..., 1] ** 2 + (1. - rvec[..., 1] ** 2) * costh, + rvec[..., 1] * rvec[..., 2] * (1. - costh) - rvec[..., 0] * sinth), dim=-1), + + torch.stack((rvec[..., 0] * rvec[..., 2] * (1. - costh) - rvec[..., 1] * sinth, + rvec[..., 1] * rvec[..., 2] * (1. - costh) + rvec[..., 0] * sinth, + rvec[..., 2] ** 2 + (1. - rvec[..., 2] ** 2) * costh), dim=-1)), + dim=-2) + +def rotation_interp(r0, r1, alpha): + r0a = r0.view(-1, 3, 3) + r1a = r1.view(-1, 3, 3) + r = torch.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0) + + th, rvec = matrix_to_axisangle(r) + rvec = rvec * (alpha * th) + + r = axisangle_to_matrix(rvec) + return torch.bmm(r0a, r.view(-1, 3, 3)).view_as(r0) + +def fuse(trainiter=None, renderoptions={}): + def _fuse(m): + if hasattr(m, "fuse") and isinstance(m, torch.nn.Module): + if m.fuse.__code__.co_argcount > 1: + m.fuse(trainiter, renderoptions) + else: + m.fuse() + return _fuse + +def no_grad(m): + for p in m.parameters(): + p.requires_grad = False diff --git a/dva/mvp/models/volumetric.py b/dva/mvp/models/volumetric.py new file mode 100644 index 0000000000000000000000000000000000000000..bd09ff0bb12a2e2a6b6dcf879718f09b0b63fc01 --- /dev/null +++ b/dva/mvp/models/volumetric.py @@ -0,0 +1,383 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
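The rotation helpers at the end of models/utils.py (Rodrigues, axisangle_to_matrix, rotation_interp) build 3x3 rotation matrices from axis-angle vectors via the Rodrigues formula, with a small 1e-5 epsilon added under the square root for stability near zero angles. A quick check of axisangle_to_matrix, under the same import-path assumption as above:

import math
import torch
from models.utils import axisangle_to_matrix

rvec = torch.tensor([[0.0, 0.0, math.pi / 2]])  # 90 degrees about +z
print(axisangle_to_matrix(rvec))
# approximately [[[0., -1., 0.],
#                 [1.,  0., 0.],
#                 [0.,  0., 1.]]] (up to the 1e-5 angle epsilon)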
+""" Volumetric autoencoder (image -> encoding -> volume -> image) """ +import inspect +import time +from typing import Optional + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models.utils + +from extensions.utils.utils import compute_raydirs + +@torch.jit.script +def compute_raydirs_ref(pixelcoords : torch.Tensor, viewrot : torch.Tensor, focal : torch.Tensor, princpt : torch.Tensor): + raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :] + raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1) + raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2) + raydir = F.normalize(raydir, dim=-1) + + return raydir + +@torch.jit.script +def compute_rmbounds(viewpos : torch.Tensor, raydir : torch.Tensor, volradius : float): + viewpos = viewpos / volradius + + # compute raymarching starting points + with torch.no_grad(): + t1 = (-1. - viewpos[:, None, None, :]) / raydir + t2 = ( 1. - viewpos[:, None, None, :]) / raydir + tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]), + torch.max(torch.min(t1[..., 1], t2[..., 1]), + torch.min(t1[..., 2], t2[..., 2]))) + tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]), + torch.min(torch.max(t1[..., 1], t2[..., 1]), + torch.max(t1[..., 2], t2[..., 2]))) + + intersections = tmin < tmax + t = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.) + tmin = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.) + tmax = torch.where(intersections, tmax, torch.zeros_like(tmin)) + + raypos = viewpos[:, None, None, :] + raydir * 0. + tminmax = torch.stack([tmin, tmax], dim=-1) + + return raypos, tminmax + +class Autoencoder(nn.Module): + def __init__(self, dataset, encoder, decoder, raymarcher, colorcal, + volradius, bgmodel=None, encoderinputs=[], topology=None, + imagemean=0., imagestd=1., vertmask=None, cudaraydirs=True): + super(Autoencoder, self).__init__() + + self.encoder = encoder + self.decoder = decoder + self.raymarcher = raymarcher + self.colorcal = colorcal + self.volradius = volradius + self.bgmodel = bgmodel + self.encoderinputs = encoderinputs + + if hasattr(dataset, 'vertmean'): + self.register_buffer("vertmean", torch.from_numpy(dataset.vertmean), persistent=False) + self.vertstd = dataset.vertstd + if hasattr(dataset, 'texmean'): + self.register_buffer("texmean", torch.from_numpy(dataset.texmean), persistent=False) + self.texstd = dataset.texstd + self.imagemean = imagemean + self.imagestd = imagestd + + self.cudaraydirs = cudaraydirs + + if vertmask is not None: + self.register_buffer("vertmask", torch.from_numpy(vertmask), persistent=False) + + self.irgbmsestart = -1 + + def forward(self, + camrot : torch.Tensor, + campos : torch.Tensor, + focal : torch.Tensor, + princpt : torch.Tensor, + camindex : Optional[torch.Tensor] = None, + pixelcoords : Optional[torch.Tensor]=None, + modelmatrix : Optional[torch.Tensor]=None, + modelmatrixinv : Optional[torch.Tensor]=None, + modelmatrix_next : Optional[torch.Tensor]=None, + modelmatrixinv_next : Optional[torch.Tensor]=None, + validinput : Optional[torch.Tensor]=None, + avgtex : Optional[torch.Tensor]=None, + avgtex_next : Optional[torch.Tensor]=None, + verts : Optional[torch.Tensor]=None, + verts_next : Optional[torch.Tensor]=None, + fixedcamimage : Optional[torch.Tensor]=None, + encoding : Optional[torch.Tensor]=None, + image : Optional[torch.Tensor]=None, + imagemask : Optional[torch.Tensor]=None, + imagevalid : Optional[torch.Tensor]=None, + bg 
: Optional[torch.Tensor]=None, + renderoptions : dict ={}, + trainiter : int=-1, + evaliter : Optional[torch.Tensor]=None, + outputlist : list=[], + losslist : list=[], + **kwargs): + """ + Parameters + ---------- + camrot : torch.Tensor [B, 3, 3] + Rotation matrix of target view camera + campos : torch.Tensor [B, 3] + Position of target view camera + focal : torch.Tensor [B, 2] + Focal length of target view camera + princpt : torch.Tensor [B, 2] + Princple point of target view camera + camindex : torch.Tensor[int32], optional [B] + Camera index within the list of all cameras + pixelcoords : torch.Tensor, optional [B, H', W', 2] + Pixel coordinates to render of the target view camera + modelmatrix : torch.Tensor, optional [B, 3, 3] + Relative transform from the 'neutral' pose of object + validinput : torch.Tensor, optional [B] + Whether the current batch element is valid (used for missing images) + avgtex : torch.Tensor, optional [B, 3, 1024, 1024] + Texture map averaged from all viewpoints + verts : torch.Tensor, optional [B, 7306, 3] + Mesh vertex positions + fixedcamimage : torch.Tensor, optional [B, 3, 512, 334] + Camera images from a one or more cameras that are always the same + (i.e., unrelated to target) + encoding : torch.Tensor, optional [B, 256] + Direct encodings (overrides encoder) + image : torch.Tensor, optional [B, 3, H, W] + Target image + imagemask : torch.Tensor, optional [B, 1, H, W] + Target image mask for computing reconstruction loss + imagevalid : torch.Tensor, optional [B] + bg : torch.Tensor, optional [B, 3, H, W] + renderoptions : dict + Rendering/raymarching options (e.g., stepsize, whether to output debug images, etc.) + trainiter : int + Training iteration number + outputlist : list + Values to return (e.g., image reconstruction, debug output) + losslist : list + Losses to output (e.g., image reconstruction loss, priors) + + Returns + ------- + result : dict + Contains outputs specified in outputlist (e.g., image rgb + reconstruction "irgbrec") + losses : dict + Losses to optimize + """ + resultout = {} + resultlosses = {} + + aestart = time.time() + + # encode/get encoding + # verts [6, 7306, 3] + # avgtex [6, 3, 256, 256] + if encoding is None: + if "enctime" in outputlist: + torch.cuda.synchronize() + encstart = time.time() + encout, enclosses = self.encoder( + *[dict(verts=verts, avgtex=avgtex, fixedcamimage=fixedcamimage)[k] for k in self.encoderinputs], + losslist=losslist) + if "enctime" in outputlist: + torch.cuda.synchronize() + encend = time.time() + resultout["enctime"] = encend - encstart + + # encoding [6, 256] + encoding = encout["encoding"] + resultlosses.update(enclosses) + + # compute relative viewing position + if modelmatrixinv is not None: + viewrot = torch.bmm(camrot, modelmatrixinv[:, :3, :3]) + viewpos = torch.bmm((campos[:, :] - modelmatrixinv[:, :3, 3])[:, None, :], modelmatrixinv[:, :3, :3])[:, 0, :] + else: + viewrot = camrot + viewpos = campos + + # decode volumetric representation + if "dectime" in outputlist: + torch.cuda.synchronize() + decstart = time.time() + if isinstance(self.decoder, torch.jit.ScriptModule): + # torchscript requires statically typed dict + renderoptionstyped : Dict[str, str] = {k: str(v) for k, v in renderoptions.items()} + else: + renderoptionstyped = renderoptions + decout, declosses = self.decoder( + encoding, + viewpos, + renderoptions=renderoptionstyped, + trainiter=trainiter, + evaliter=evaliter, + losslist=losslist) + if "dectime" in outputlist: + torch.cuda.synchronize() + decend = time.time() + 
resultout["dectime"] = decend - decstart + resultlosses.update(declosses) + + # compute vertex loss + if "vertmse" in losslist: + weight = validinput[:, None, None].expand_as(verts) + + if hasattr(self, "vertmask"): + weight = weight * self.vertmask[None, :, None] + + vertsrecstd = (decout["verts"] - self.vertmean) / self.vertstd + + vertsqerr = weight * (verts - vertsrecstd) ** 2 + + vertmse = torch.sum(vertsqerr.view(vertsqerr.size(0), -1), dim=-1) + vertmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1) + + resultlosses["vertmse"] = (vertmse, vertmse_weight) + + # compute texture loss + if "trgbmse" in losslist or "trgbsqerr" in outputlist: + weight = (validinput[:, None, None, None] * texmask[:, None, :, :].float()).expand_as(tex).contiguous() + + # re-standardize + texrecstd = (decout["tex"] - self.texmean.to("cuda")) / self.texstd + texstd = (tex - self.texmean.to("cuda")) / self.texstd + + texsqerr = weight * (texstd - texrecstd) ** 2 + + if "trgbsqerr" in outputlist: + resultout["trgbsqerr"] = texsqerr + + # texture rgb mean-squared-error + if "trgbmse" in losslist: + texmse = torch.sum(texsqerr.view(texsqerr.size(0), -1), dim=-1) + texmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1) + + resultlosses["trgbmse"] = (texmse, texmse_weight) + + # subsample depth, imagerec, imagerecmask + if image is not None and pixelcoords.size()[1:3] != image.size()[2:4]: + imagesize = torch.tensor(image.size()[3:1:-1], dtype=torch.float32, device=pixelcoords.device) + else: + imagesize = torch.tensor(pixelcoords.size()[2:0:-1], dtype=torch.float32, device=pixelcoords.device) + + samplecoords = pixelcoords * 2. / (imagesize[None, None, None, :] - 1.) - 1. + + # compute ray directions + if self.cudaraydirs: + raypos, raydir, tminmax = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, self.volradius) + else: + raydir = compute_raydirs_ref(pixelcoords, viewrot, focal, princpt) + raypos, tminmax = compute_rmbounds(viewpos, raydir, self.volradius) + + if "dtstd" in renderoptions: + renderoptions["dt"] = renderoptions["dt"] * \ + torch.exp(torch.randn(1) * renderoptions.get("dtstd")).item() + + if renderoptions.get("unbiastminmax", False): + stepsize = renderoptions["dt"] / self.volradius + tminmax = torch.floor(tminmax / stepsize) * stepsize + + if renderoptions.get("tminmaxblocks", False): + bx, by = renderoptions.get("blocksize", (8, 16)) + H, W = tminmax.size(1), tminmax.size(2) + tminmax = tminmax.view(tminmax.size(0), H // by, by, W // bx, bx, 2) + tminmax = tminmax.amin(dim=[2, 4], keepdim=True) + tminmax = tminmax.repeat(1, 1, by, 1, bx, 1) + tminmax = tminmax.view(tminmax.size(0), H, W, 2) + + # raymarch + if "rmtime" in outputlist: + torch.cuda.synchronize() + rmstart = time.time() + # rayrgba [6, 4, 384, 384] + rayrgba, rmlosses = self.raymarcher(raypos, raydir, tminmax, + decout=decout, renderoptions=renderoptions, + trainiter=trainiter, evaliter=evaliter, losslist=losslist) + resultlosses.update(rmlosses) + if "rmtime" in outputlist: + torch.cuda.synchronize() + rmend = time.time() + resultout["rmtime"] = rmend - rmstart + + if isinstance(rayrgba, tuple): + rayrgb, rayalpha = rayrgba + else: + rayrgb, rayalpha = rayrgba[:, :3, :, :].contiguous(), rayrgba[:, 3:4, :, :].contiguous() + + # beta distribution prior on final opacity + if "alphapr" in losslist: + alphaprior = torch.mean( + torch.log(0.1 + rayalpha.view(rayalpha.size(0), -1)) + + torch.log(0.1 + 1. 
- rayalpha.view(rayalpha.size(0), -1)) - -2.20727, dim=-1) + resultlosses["alphapr"] = alphaprior + + # color correction + if camindex is not None and not renderoptions.get("nocolcorrect", False): + rayrgb = self.colorcal(rayrgb, camindex) + + # background decoder + if self.bgmodel is not None and not renderoptions.get("nobg", False): + if "bgtime" in outputlist: + torch.cuda.synchronize() + bgstart = time.time() + + raypos, raydir, tminmax = compute_raydirs(campos, camrot, focal, princpt, pixelcoords, self.volradius) + + rayposbeg = raypos + raydir * tminmax[..., 0:1] + rayposend = raypos + raydir * tminmax[..., 1:2] + + bg = self.bgmodel(bg, camindex, campos, rayposend, raydir, samplecoords, trainiter=trainiter) + + # alpha matting + if bg is not None: + rayrgb = rayrgb + (1. - rayalpha) * bg + + if "bg" in outputlist: + resultout["bg"] = bg + + if "bgtime" in outputlist: + torch.cuda.synchronize() + bgend = time.time() + resultout["bgtime"] = bgend - bgstart + + if "irgbrec" in outputlist: + resultout["irgbrec"] = rayrgb + if "irgbarec" in outputlist: + resultout["irgbarec"] = torch.cat([rayrgb, rayalpha], dim=1) + if "irgbflip" in outputlist: + resultout["irgbflip"] = torch.cat([rayrgb[i:i+1] if i % 4 < 2 else image[i:i+1] + for i in range(image.size(0))], dim=0) + + # image rgb loss + if image is not None and trainiter > self.irgbmsestart: + # subsample image + if pixelcoords.size()[1:3] != image.size()[2:4]: + image = F.grid_sample(image, samplecoords, align_corners=True) + if imagemask is not None: + imagemask = F.grid_sample(imagemask, samplecoords, align_corners=True) + + # compute reconstruction loss weighting + weight = torch.ones_like(image) * validinput[:, None, None, None] + if imagevalid is not None: + weight = weight * imagevalid[:, None, None, None] + if imagemask is not None: + weight = weight * imagemask + + if "irgbsqerr" in outputlist: + irgbsqerr_nonorm = (weight * (image - rayrgb) ** 2).contiguous() + resultout["irgbsqerr"] = torch.sqrt(irgbsqerr_nonorm.mean(dim=1, keepdim=True)) + + # standardize + rayrgb = (rayrgb - self.imagemean) / self.imagestd + image = (image - self.imagemean) / self.imagestd + + irgbsqerr = (weight * (image - rayrgb) ** 2).contiguous() + + if "irgbmse" in losslist: + irgbmse = torch.sum(irgbsqerr.view(irgbsqerr.size(0), -1), dim=-1) + irgbmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1) + + resultlosses["irgbmse"] = (irgbmse, irgbmse_weight) + + aeend = time.time() + if "aetime" in outputlist: + resultout["aetime"] = aeend - aestart + + return resultout, resultlosses diff --git a/dva/ray_marcher.py b/dva/ray_marcher.py new file mode 100644 index 0000000000000000000000000000000000000000..482b3c122fa4a312b14757af34be84f032c08dab --- /dev/null +++ b/dva/ray_marcher.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
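In the autoencoder above, the "alphapr" term is the beta-distribution prior on final ray opacity: log(0.1 + alpha) + log(1.1 - alpha) + 2.20727, averaged over rays. The constant 2.20727 is -(log(0.1) + log(1.1)), so the term is roughly zero at alpha = 0 or 1 and positive in between; minimizing it therefore pushes rays toward fully transparent or fully opaque. A quick numeric check:

import math
for a in (0.0, 0.5, 1.0):
    print(a, math.log(0.1 + a) + math.log(0.1 + 1.0 - a) + 2.20727)
# ~0.000  ~1.186  ~0.000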
+ +from typing import Optional, Dict, Tuple +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +import random + +from dva.mvp.extensions.mvpraymarch.mvpraymarch import mvpraymarch +from dva.mvp.extensions.utils.utils import compute_raydirs + +import logging + +logger = logging.getLogger(__name__) + + +def convert_camera_parameters(Rt, K): + R = Rt[:, :3, :3] + t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2) + return dict( + campos=t, + camrot=R, + focal=K[:, :2, :2], + princpt=K[:, :2, 2], + ) + +def subsample_pixel_coords( + pixel_coords: th.Tensor, batch_size: int, ray_subsample_factor: int = 4 +): + + H, W = pixel_coords.shape[:2] + SW = W // ray_subsample_factor + SH = H // ray_subsample_factor + + all_coords = [] + for _ in range(batch_size): + # TODO: this is ugly, switch to pytorch? + x0 = th.randint(0, ray_subsample_factor - 1, size=()) + y0 = th.randint(0, ray_subsample_factor - 1, size=()) + dx = ray_subsample_factor + dy = ray_subsample_factor + x1 = x0 + dx * SW + y1 = y0 + dy * SH + all_coords.append(pixel_coords[y0:y1:dy, x0:x1:dx, :]) + all_coords = th.stack(all_coords, dim=0) + return all_coords + + +def resize_pixel_coords( + pixel_coords: th.Tensor, batch_size: int, ray_subsample_factor: int = 4 +): + + H, W = pixel_coords.shape[:2] + SW = W // ray_subsample_factor + SH = H // ray_subsample_factor + + all_coords = [] + for _ in range(batch_size): + # TODO: this is ugly, switch to pytorch? + x0, y0 = ray_subsample_factor // 2, ray_subsample_factor // 2 + dx = ray_subsample_factor + dy = ray_subsample_factor + x1 = x0 + dx * SW + y1 = y0 + dy * SH + all_coords.append(pixel_coords[y0:y1:dy, x0:x1:dx, :]) + all_coords = th.stack(all_coords, dim=0) + return all_coords + + +class RayMarcher(nn.Module): + def __init__( + self, + image_height, + image_width, + volradius, + fadescale=8.0, + fadeexp=8.0, + dt=1.0, + ray_subsample_factor=1, + accum=2, + termthresh=0.99, + blocksize=None, + with_t_img=True, + chlast=False, + assets=None, + ): + super().__init__() + + # TODO: add config? + self.image_height = image_height + self.image_width = image_width + self.volradius = volradius + self.dt = dt + + self.fadescale = fadescale + self.fadeexp = fadeexp + + # NOTE: this seems to not work for other configs? 
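+ # the (8, 16) default mirrors the pixel block size that the "tminmaxblocks"
+ # option in dva/mvp/models/volumetric.py also defaults to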
+ if blocksize is None: + blocksize = (8, 16) + + self.blocksize = blocksize + self.with_t_img = with_t_img + self.chlast = chlast + + self.accum = accum + self.termthresh = termthresh + + base_pixel_coords = th.stack( + th.meshgrid( + th.arange(self.image_height, dtype=th.float32), + th.arange(self.image_width, dtype=th.float32), + )[::-1], + dim=-1, + ) + self.register_buffer("base_pixel_coords", base_pixel_coords, persistent=False) + self.fixed_bvh_cache = {-1: (th.empty(0), th.empty(0), th.empty(0))} + self.ray_subsample_factor = ray_subsample_factor + + def _set_pix_coords(self): + dev = self.base_pixel_coords.device + self.base_pixel_coords = th.stack( + th.meshgrid( + th.arange(self.image_height, dtype=th.float32, device=dev), + th.arange(self.image_width, dtype=th.float32, device=dev), + )[::-1], + dim=-1, + ) + + def resize(self, h: int, w: int): + self.image_height = h + self.image_width = w + + self._set_pix_coords() + + def forward( + self, + prim_rgba: th.Tensor, + prim_pos: th.Tensor, + prim_rot: th.Tensor, + prim_scale: th.Tensor, + K: th.Tensor, + RT: th.Tensor, + ray_subsample_factor: Optional[int] = None, + ): + """ + Args: + prim_rgba: primitive payload [B, K, 4, S, S, S], + K - # of primitives, S - primitive size + prim_pos: locations [B, K, 3] + prim_rot: rotations [B, K, 3, 3] + prim_scale: scales [B, K, 3] + K: intrinsics [B, 3, 3] + RT: extrinsics [B, 3, 4] + Returns: + a dict of tensors + """ + # TODO: maybe we can re-use mvpraymarcher? + B = prim_rgba.shape[0] + device = prim_rgba.device + + # TODO: this should return focal 2x2? + camera = convert_camera_parameters(RT, K) + camera = {k: v.contiguous() for k, v in camera.items()} + + dt = self.dt / self.volradius + + if ray_subsample_factor is None: + ray_subsample_factor = self.ray_subsample_factor + + if ray_subsample_factor > 1 and self.training: + pixel_coords = subsample_pixel_coords( + self.base_pixel_coords, int(B), ray_subsample_factor + ) + elif ray_subsample_factor > 1: + pixel_coords = resize_pixel_coords( + self.base_pixel_coords, + int(B), + ray_subsample_factor, + ) + else: + pixel_coords = ( + self.base_pixel_coords[np.newaxis].expand(B, -1, -1, -1).contiguous() + ) + + prim_pos = prim_pos / self.volradius + + focal = th.diagonal(camera["focal"], dim1=1, dim2=2).contiguous() + + # TODO: port this? 
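+ # compute_raydirs gives, per pixel: the camera position in normalized volume
+ # coordinates (raypos), a unit ray direction (raydir), and the [tmin, tmax]
+ # interval where the ray crosses the [-1, 1]^3 cube; see the pure-PyTorch
+ # reference compute_rmbounds in dva/mvp/models/volumetric.py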
+ raypos, raydir, tminmax = compute_raydirs( + viewpos=camera["campos"], + viewrot=camera["camrot"], + focal=focal, + princpt=camera["princpt"], + pixelcoords=pixel_coords, + volradius=self.volradius, + ) + + rgba = mvpraymarch( + raypos, + raydir, + stepsize=dt, + tminmax=tminmax, + algo=0, + template=prim_rgba.permute(0, 1, 3, 4, 5, 2).contiguous(), + warp=None, + termthresh=self.termthresh, + primtransf=(prim_pos, prim_rot, prim_scale), + fadescale=self.fadescale, + fadeexp=self.fadeexp, + usebvh="fixedorder", + chlast=True, + ) + + rgba = rgba.permute(0, 3, 1, 2) + + preds = { + "rgba_image": rgba, + "pixel_coords": pixel_coords, + } + + return preds + + +def generate_colored_boxes(template, prim_rot, alpha=10000.0, seed=123456): + B = template.shape[0] + output = template.clone() + device = template.device + + lightdir = -3 * th.ones([B, 3], dtype=th.float32, device=device) + lightdir = lightdir / th.norm(lightdir, p=2, dim=1, keepdim=True) + + zz, yy, xx = th.meshgrid( + th.linspace(-1.0, 1.0, template.size(-1), device=device), + th.linspace(-1.0, 1.0, template.size(-1), device=device), + th.linspace(-1.0, 1.0, template.size(-1), device=device), + ) + primnormalx = th.where( + (th.abs(xx) >= th.abs(yy)) & (th.abs(xx) >= th.abs(zz)), + th.sign(xx) * th.ones_like(xx), + th.zeros_like(xx), + ) + primnormaly = th.where( + (th.abs(yy) >= th.abs(xx)) & (th.abs(yy) >= th.abs(zz)), + th.sign(yy) * th.ones_like(xx), + th.zeros_like(xx), + ) + primnormalz = th.where( + (th.abs(zz) >= th.abs(xx)) & (th.abs(zz) >= th.abs(yy)), + th.sign(zz) * th.ones_like(xx), + th.zeros_like(xx), + ) + primnormal = th.stack([primnormalx, -primnormaly, -primnormalz], dim=-1) + primnormal = primnormal / th.sqrt(th.sum(primnormal**2, dim=-1, keepdim=True)) + + output[:, :, 3, :, :, :] = alpha + + np.random.seed(seed) + + for i in range(template.size(1)): + # generating a random color + output[:, i, 0, :, :, :] = np.random.rand() * 255.0 + output[:, i, 1, :, :, :] = np.random.rand() * 255.0 + output[:, i, 2, :, :, :] = np.random.rand() * 255.0 + + # get light direction in local coordinate system? 
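+ # each primitive gets a random base color; its brightness is then scaled by
+ # 1.4x the dot product between the light direction and the dominant-axis
+ # box-face normal, clamped to 0.2 so faces pointing away from the light stay visible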
+ lightdir0 = lightdir + mult = th.sum( + lightdir0[:, None, None, None, :] * primnormal[np.newaxis], dim=-1 + )[:, np.newaxis, :, :, :].clamp(min=0.2) + output[:, i, :3, :, :, :] *= 1.4 * mult + return output diff --git a/dva/scheduler.py b/dva/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..71fa46fb580b7f38d021727baaa88024c04368fe --- /dev/null +++ b/dva/scheduler.py @@ -0,0 +1,21 @@ +import math +from torch.optim.lr_scheduler import LRScheduler + +class CosineWarmupScheduler(LRScheduler): + def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): + self.warmup_iters = warmup_iters + self.max_iters = max_iters + self.initial_lr = initial_lr + super().__init__(optimizer, last_iter) + + def get_lr(self): + if self._step_count <= self.warmup_iters: + return [ + self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters + for base_lr in self.base_lrs] + else: + cos_iter = self._step_count - self.warmup_iters + cos_max_iter = self.max_iters - self.warmup_iters + cos_theta = cos_iter / cos_max_iter * math.pi + cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] + return cos_lr diff --git a/dva/utils.py b/dva/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3adff3e89da88d0944eb0e0d491a9a9f371c1d --- /dev/null +++ b/dva/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +import cv2 + + +def label_image( + image, + label, + font_scale=1.0, + font_thickness=1, + label_origin=(10, 64), + font_color=(255, 255, 255), + font=cv2.FONT_HERSHEY_SIMPLEX, +): + text_size, baseline = cv2.getTextSize(label, font, font_scale, font_thickness) + image[ + label_origin[1] - text_size[1] : label_origin[1] + baseline, + label_origin[0] : label_origin[0] + text_size[0], + ] = (255 - font_color[0], 255 - font_color[1], 255 - font_color[2]) + cv2.putText( + image, label, label_origin, font, font_scale, font_color, font_thickness + ) + return image + + +def to_device(values, device=None, non_blocking=True): + """Transfer a set of values to the device. + Args: + values: a nested dict/list/tuple of tensors + device: argument to `to()` for the underlying vector + NOTE: + if the device is not specified, using `th.cuda()` + """ + if device is None: + device = th.device("cuda") + + if isinstance(values, dict): + return {k: to_device(v, device=device) for k, v in values.items()} + elif isinstance(values, tuple): + return tuple(to_device(v, device=device) for v in values) + elif isinstance(values, list): + return [to_device(v, device=device) for v in values] + elif isinstance(values, th.Tensor): + return values.to(device, non_blocking=non_blocking) + elif isinstance(values, nn.Module): + return values.to(device) + elif isinstance(values, np.ndarray): + return th.from_numpy(values).to(device) + else: + return values diff --git a/dva/vgg.py b/dva/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..e27b27bd7761ced1b03cfda9814a3b5bd8994ac1 --- /dev/null +++ b/dva/vgg.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch as th +import torch.nn as nn +from torchvision.models import vgg19 +import torch.nn.functional as F +import logging + +logger = logging.getLogger(__name__) + + +class Vgg19(nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg19_network = vgg19(pretrained=True) + # vgg19_network.load_state_dict(state_dict) + vgg_pretrained_features = vgg19_network.features + self.slice1 = nn.Sequential() + self.slice2 = nn.Sequential() + self.slice3 = nn.Sequential() + self.slice4 = nn.Sequential() + self.slice5 = nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class VGGLossMasked(nn.Module): + def __init__(self, weights=None): + super().__init__() + self.vgg = Vgg19() + if weights is None: + # self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] + self.weights = [20.0, 5.0, 0.9, 0.5, 0.5] + else: + self.weights = weights + + def normalize(self, batch): + mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + return ((batch / 255.0).clamp(0.0, 1.0) - mean) / std + + def forward(self, x_rgb, y_rgb, mask): + + x_norm = self.normalize(x_rgb) + y_norm = self.normalize(y_rgb) + + x_vgg = self.vgg(x_norm) + y_vgg = self.vgg(y_norm) + loss = 0 + for i in range(len(x_vgg)): + if isinstance(mask, th.Tensor): + m = F.interpolate( + mask, size=(x_vgg[i].shape[-2], x_vgg[i].shape[-1]), mode="bilinear" + ).detach() + else: + m = mask + + vx = x_vgg[i] * m + vy = y_vgg[i] * m + + loss += self.weights[i] * (vx - vy).abs().mean() + + # logger.info( + # f"loss for {i}, {loss.item()} vx={vx.shape} vy={vy.shape} {vx.max()} {vy.max()}" + # ) + return loss diff --git a/dva/visualize.py b/dva/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..240d8f30f4de1a24f66691edf33686622d8512ec --- /dev/null +++ b/dva/visualize.py @@ -0,0 +1,478 @@ +import cv2 +import os +import numpy as np +import torch +import imageio +from torchvision.utils import make_grid, save_image +from .ray_marcher import RayMarcher, generate_colored_boxes + +def get_pose_on_orbit(radius, height, angles, world_up=torch.Tensor([0, 1, 0])): + num_points = angles.shape[0] + x = radius * torch.cos(angles) + h = torch.ones((num_points,)) * height + z = radius * torch.sin(angles) + position = torch.stack([x, h, z], dim=-1) + forward = position / torch.norm(position, p=2, dim=-1, keepdim=True) + right = -torch.cross(world_up[None, ...], forward) + right /= torch.norm(right, dim=-1, keepdim=True) + up = torch.cross(forward, right) + up /= torch.norm(up, p=2, dim=-1, keepdim=True) + rotation = torch.stack([right, up, forward], dim=1) + translation = 
torch.Tensor([0, 0, radius])[None, :, None].repeat(num_points, 1, 1) + return torch.concat([rotation, translation], dim=2) + +def render_mvp_boxes(rm, batch, preds): + with torch.no_grad(): + boxes_rgba = generate_colored_boxes( + preds["prim_rgba"], + preds["prim_rot"], + ) + preds_boxes = rm( + prim_rgba=boxes_rgba, + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + + return preds_boxes["rgba_image"][:, :3].permute(0, 2, 3, 1) + + +def save_image_summary(path, batch, preds): + rgb = preds["rgb"].detach().permute(0, 3, 1, 2) + # rgb_gt = batch["image"] + rgb_boxes = preds["rgb_boxes"].detach().permute(0, 3, 1, 2) + bs = rgb_boxes.shape[0] + if "folder" in batch and "key" in batch: + obj_list = [] + for bs_idx in range(bs): + tmp_img = rgb_boxes[bs_idx].permute(1, 2, 0).to(torch.uint8).cpu().numpy() + tmp_img = np.ascontiguousarray(tmp_img) + folder = batch['folder'][bs_idx] + key = batch['key'][bs_idx] + obj_list.append("{}/{}\n".format(folder, key)) + cv2.putText(tmp_img, "{}".format(folder), (200, 200), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2) + cv2.putText(tmp_img, "{}".format(key), (200, 400), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2) + tmp_img_torch = torch.as_tensor(tmp_img).permute(2, 0, 1).float() + rgb_boxes[bs_idx] = tmp_img_torch + with open(os.path.splitext(path)[0]+".txt", "w") as f: + f.writelines(obj_list) + img = make_grid(torch.cat([rgb, rgb_boxes], dim=2) / 255.0).clip(0.0, 1.0) + save_image(img, path) + + +@torch.no_grad() +def visualize_primsdf_box(image_save_path, model, rm: RayMarcher, device): + # prim_rgba: primitive payload [B, K, 4, S, S, S], + # K - # of primitives, S - primitive size + # prim_pos: locations [B, K, 3] + # prim_rot: rotations [B, K, 3, 3] + # prim_scale: scales [B, K, 3] + # K: intrinsics [B, 3, 3] + # RT: extrinsics [B, 3, 4] + preds = {} + batch = {} + prim_alpha = model.sdf2alpha(model.feat_geo).reshape(1, model.num_prims, 1, model.prim_shape, model.prim_shape, model.prim_shape) * 255 + prim_rgb = model.feat_tex.reshape(1, model.num_prims, 3, model.prim_shape, model.prim_shape, model.prim_shape) * 255 + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + preds['prim_pos'] = model.pos.reshape(1, model.num_prims, 3) * rm.volradius + preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(1, model.num_prims, 1, 1) + preds['prim_scale'] = (1 / model.scale.reshape(1, model.num_prims, 1).repeat(1, 1, 3)) + batch['Rt'] = torch.Tensor([ + [ + 1.0, + 0.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + -1.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + 0.0, + -1.0, + 5 * rm.volradius + ] + ]).to(device)[None, ...] + batch['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(device)[None, ...] + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. 
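+ # the hard-coded intrinsics assume a 1024 x 1024 render (principal point at 512),
+ # so rescale the focal / principal-point rows to the raymarcher's actual resolution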
+ batch['K'][:, 0:1, :] *= ratio_h + batch['K'][:, 1:2, :] *= ratio_w + # raymarcher is in mm + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1) + preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous()) + with torch.no_grad(): + preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds) + save_image_summary(image_save_path, batch, preds) + +@torch.no_grad() +def render_primsdf(image_save_path, model, rm, device): + preds = {} + batch = {} + preds['prim_pos'] = model.pos.reshape(1, model.num_prims, 3) * rm.volradius + preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(1, model.num_prims, 1, 1) + preds['prim_scale'] = (1 / model.scale.reshape(1, model.num_prims, 1).repeat(1, 1, 3)) + batch['Rt'] = torch.Tensor([ + [ + 1.0, + 0.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + -1.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + 0.0, + -1.0, + 5 * rm.volradius + ] + ]).to(device)[None, ...] + batch['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(device)[None, ...] + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. + batch['K'][:, 0:1, :] *= ratio_h + batch['K'][:, 1:2, :] *= ratio_w + # test rendering + all_sampled_sdf = [] + all_sampled_tex = [] + for i in range(model.prim_shape ** 3): + with torch.no_grad(): + model_prediction = model(model.sdf_sampled_point[:, i, :].to(device)) + sampled_sdf = model_prediction['sdf'] + sampled_rgb = model_prediction['tex'] + all_sampled_sdf.append(sampled_sdf) + all_sampled_tex.append(sampled_rgb) + sampled_sdf = torch.stack(all_sampled_sdf, dim=1) + sampled_tex = torch.stack(all_sampled_tex, dim=1).permute(0, 2, 1).reshape(1, model.num_prims, 3, model.prim_shape, model.prim_shape, model.prim_shape) * 255 + prim_rgb = sampled_tex + prim_alpha = model.sdf2alpha(sampled_sdf).reshape(1, model.num_prims, 1, model.prim_shape, model.prim_shape, model.prim_shape) * 255 + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + + rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1) + preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous()) + with torch.no_grad(): + preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds) + save_image_summary(image_save_path, batch, preds) + +@torch.no_grad() +def visualize_primvolume(image_save_path, batch, prim_volume, rm: RayMarcher, device): + # prim_volume - [B, nprims, 4+6*8^3] + def sdf2alpha(sdf): + return torch.exp(-(sdf / 0.005) ** 2) + preds = {} + prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1/3))) + num_prims = prim_volume.shape[1] + bs = prim_volume.shape[0] + geo_start_index = 4 + geo_end_index = geo_start_index + prim_shape ** 3 # non-inclusive + tex_start_index = geo_end_index + tex_end_index = tex_start_index + prim_shape ** 3 * 3 # non-inclusive + mat_start_index = tex_end_index + mat_end_index = mat_start_index + prim_shape ** 3 * 2 + + feat_geo = prim_volume[:, :, geo_start_index: geo_end_index] + feat_tex = prim_volume[:, :, tex_start_index: tex_end_index] + prim_alpha = sdf2alpha(feat_geo).reshape(bs, num_prims, 1, prim_shape, prim_shape, 
prim_shape) * 255 + prim_rgb = feat_tex.reshape(bs, num_prims, 3, prim_shape, prim_shape, prim_shape) * 255 + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + pos = prim_volume[:, :, 1:4] + scale = prim_volume[:, :, 0:1] + preds['prim_pos'] = pos.reshape(bs, num_prims, 3) * rm.volradius + preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, num_prims, 1, 1) + preds['prim_scale'] = (1 / scale.reshape(bs, num_prims, 1).repeat(1, 1, 3)) + batch['Rt'] = torch.Tensor([ + [ + 1.0, + 0.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + -1.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + 0.0, + -1.0, + 5 * rm.volradius + ] + ]).to(device)[None, ...].repeat(bs, 1, 1) + batch['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(device)[None, ...].repeat(bs, 1, 1) + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. + batch['K'][:, 0:1, :] *= ratio_h + batch['K'][:, 1:2, :] *= ratio_w + # raymarcher is in mm + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1) + preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous()) + with torch.no_grad(): + preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds) + save_image_summary(image_save_path, batch, preds) + +@torch.no_grad() +def visualize_multiview_primvolume(image_save_path, batch, prim_volume, view_counts, rm: RayMarcher, device): + # prim_volume - [B, nprims, 4+6*8^3] + view_angles = torch.linspace(0.5, 2.5, view_counts + 1) * torch.pi + view_angles = view_angles[:-1] + def sdf2alpha(sdf): + return torch.exp(-(sdf / 0.005) ** 2) + preds = {} + prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1/3))) + num_prims = prim_volume.shape[1] + bs = prim_volume.shape[0] + geo_start_index = 4 + geo_end_index = geo_start_index + prim_shape ** 3 # non-inclusive + tex_start_index = geo_end_index + tex_end_index = tex_start_index + prim_shape ** 3 * 3 # non-inclusive + mat_start_index = tex_end_index + mat_end_index = mat_start_index + prim_shape ** 3 * 2 + + feat_geo = prim_volume[:, :, geo_start_index: geo_end_index] + feat_tex = prim_volume[:, :, tex_start_index: tex_end_index] + prim_alpha = sdf2alpha(feat_geo).reshape(bs, num_prims, 1, prim_shape, prim_shape, prim_shape) * 255 + prim_rgb = feat_tex.reshape(bs, num_prims, 3, prim_shape, prim_shape, prim_shape) * 255 + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + pos = prim_volume[:, :, 1:4] + scale = prim_volume[:, :, 0:1] + preds['prim_pos'] = pos.reshape(bs, num_prims, 3) * rm.volradius + preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, num_prims, 1, 1) + preds['prim_scale'] = (1 / scale.reshape(bs, num_prims, 1).repeat(1, 1, 3)) + batch['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(device)[None, ...].repeat(bs, 1, 1) + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. 
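    # Layout sketch (assuming the default prim_shape = 8): each primitive in
    # prim_volume is a flat vector
    #   [scale (1) | position (3) | SDF (8^3 = 512) | RGB (3 * 8^3 = 1536) | material (2 * 8^3 = 1024)]
    # i.e. 4 + 6 * 512 = 3076 channels, which is why prim_shape is recovered
    # above as round(((C - 4) / 6) ** (1/3)). Only the SDF and RGB payloads are
    # rasterized in this function; the material channels are consumed by the
    # video visualizer further below.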
+ batch['K'][:, 0:1, :] *= ratio_h + batch['K'][:, 1:2, :] *= ratio_w + + final_preds = {} + final_preds['rgb'] = [] + final_preds['rgb_boxes'] = [] + for view_ang in view_angles: + bs_view_ang = view_ang.repeat(bs,) + batch['Rt'] = get_pose_on_orbit(radius=5*rm.volradius, height=0, angles=bs_view_ang).to(prim_volume) + # raymarcher is in mm + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1) + preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous()) + with torch.no_grad(): + preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds) + final_preds['rgb'].append(preds['rgb']) + final_preds['rgb_boxes'].append(preds['rgb_boxes']) + final_preds['rgb'] = torch.concat(final_preds['rgb'], dim=0) + final_preds['rgb_boxes'] = torch.concat(final_preds['rgb_boxes'], dim=0) + save_image_summary(image_save_path, batch, final_preds) + + +@torch.no_grad() +def visualize_video_primvolume(video_save_folder, batch, prim_volume, view_counts, rm: RayMarcher, device): + # prim_volume - [B, nprims, 4+6*8^3] + view_angles = torch.linspace(1.5, 3.5, view_counts + 1) * torch.pi + def sdf2alpha(sdf): + return torch.exp(-(sdf / 0.005) ** 2) + preds = {} + prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1/3))) + num_prims = prim_volume.shape[1] + bs = prim_volume.shape[0] + geo_start_index = 4 + geo_end_index = geo_start_index + prim_shape ** 3 # non-inclusive + tex_start_index = geo_end_index + tex_end_index = tex_start_index + prim_shape ** 3 * 3 # non-inclusive + mat_start_index = tex_end_index + mat_end_index = mat_start_index + prim_shape ** 3 * 2 + + feat_geo = prim_volume[:, :, geo_start_index: geo_end_index] + feat_tex = prim_volume[:, :, tex_start_index: tex_end_index] + feat_mat = prim_volume[:, :, mat_start_index: mat_end_index] + prim_alpha = sdf2alpha(feat_geo).reshape(bs, num_prims, 1, prim_shape, prim_shape, prim_shape) * 255 + prim_rgb = feat_tex.reshape(bs, num_prims, 3, prim_shape, prim_shape, prim_shape) * 255 + prim_mat = feat_mat.reshape(bs, num_prims, 2, prim_shape, prim_shape, prim_shape) * 255 + dummy_prim = torch.zeros_like(prim_mat[:, :, 0:1, ...]) + prim_mat = torch.concat([dummy_prim, prim_mat], dim=2) + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + preds['prim_mata'] = torch.concat([prim_mat, prim_alpha], dim=2) + pos = prim_volume[:, :, 1:4] + scale = prim_volume[:, :, 0:1] + preds['prim_pos'] = pos.reshape(bs, num_prims, 3) * rm.volradius + preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, num_prims, 1, 1) + preds['prim_scale'] = (1 / scale.reshape(bs, num_prims, 1).repeat(1, 1, 3)) + batch['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(device)[None, ...].repeat(bs, 1, 1) + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. 
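    # Sketch of the rendering loop below: the same primitives are rendered from
    # an orbit of view_counts + 1 azimuths (torch.linspace(1.5, 3.5,
    # view_counts + 1) * pi spans one full turn, endpoint included), with three
    # passes per frame: shaded RGB, colored primitive boxes, and a material pass
    # in which roughness/metallic are packed behind a zero dummy channel. The
    # frames are encoded at 20 fps into rgb.mp4, prim.mp4 and mat.mp4 inside
    # video_save_folder; e.g. the call used at inference time,
    #   visualize_video_primvolume(current_output_dir, batch, recon_param, 60, rm, device)
    # yields 61-frame turntable videos.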
+ batch['K'][:, 0:1, :] *= ratio_h + batch['K'][:, 1:2, :] *= ratio_w + + final_preds = {} + final_preds['rgb'] = [] + final_preds['rgb_boxes'] = [] + final_preds['mat_rgb'] = [] + for view_ang in view_angles: + bs_view_ang = view_ang.repeat(bs,) + batch['Rt'] = get_pose_on_orbit(radius=5*rm.volradius, height=0, angles=bs_view_ang).to(prim_volume) + # raymarcher is in mm + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1) + preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous()) + with torch.no_grad(): + preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds) + rm_preds = rm( + prim_rgba=preds["prim_mata"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=batch["Rt"], + K=batch["K"], + ) + mat_rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1) + preds.update(mat_rgb=mat_rgba[..., :3].contiguous()) + final_preds['rgb'].append(preds['rgb']) + final_preds['rgb_boxes'].append(preds['rgb_boxes']) + final_preds['mat_rgb'].append(preds['mat_rgb']) + + assert len(final_preds['rgb']) == len(final_preds['rgb_boxes']) + final_preds['rgb'] = torch.concat(final_preds['rgb'], dim=0) + final_preds['rgb_boxes'] = torch.concat(final_preds['rgb_boxes'], dim=0) + final_preds['mat_rgb'] = torch.concat(final_preds['mat_rgb'], dim=0) + total_num_frames = final_preds['rgb'].shape[0] + rgb_video = os.path.join(video_save_folder, 'rgb.mp4') + rgb_video_out = imageio.get_writer(rgb_video, fps=20) + prim_video = os.path.join(video_save_folder, 'prim.mp4') + prim_video_out = imageio.get_writer(prim_video, fps=20) + mat_video = os.path.join(video_save_folder, 'mat.mp4') + mat_video_out = imageio.get_writer(mat_video, fps=20) + + rgb_np = np.clip(final_preds['rgb'].detach().cpu().numpy(), 0, 255).astype(np.uint8) + prim_np = np.clip(final_preds['rgb_boxes'].detach().cpu().numpy(), 0, 255).astype(np.uint8) + mat_np = np.clip(final_preds['mat_rgb'].detach().cpu().numpy(), 0, 255).astype(np.uint8) + for fidx in range(total_num_frames): + rgb_video_out.append_data(rgb_np[fidx]) + prim_video_out.append_data(prim_np[fidx]) + mat_video_out.append_data(mat_np[fidx]) + rgb_video_out.close() + prim_video_out.close() + mat_video_out.close() \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d684980ab3188c1df6c4a09c34cf122e5f246fdc --- /dev/null +++ b/inference.py @@ -0,0 +1,363 @@ +import os +import sys +import io + +import torch +import numpy as np +from omegaconf import OmegaConf +import PIL.Image +from PIL import Image +import rembg + +from dva.ray_marcher import RayMarcher +from dva.io import load_from_config +from dva.utils import to_device +from dva.visualize import visualize_primvolume, visualize_video_primvolume +from models.diffusion import create_diffusion +import logging +from tqdm import tqdm + +import mcubes +import xatlas +import nvdiffrast.torch as dr +import cv2 +from scipy.ndimage import binary_dilation, binary_erosion +from sklearn.neighbors import NearestNeighbors +from utils.meshutils import clean_mesh, decimate_mesh +from utils.mesh import Mesh +logger = logging.getLogger("inference.py") + + +def remove_background(image: PIL.Image.Image, + rembg_session = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and 
image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + +def extract_texmesh(args, model, output_path, device): + # Prepare directory + ins_dir = output_path + + # Get SDFs + with torch.no_grad(): + xx = torch.linspace(-1, 1, args.mc_resolution, device=device) + pts = torch.stack(torch.meshgrid(xx, xx, xx, indexing='ij'), dim=-1).reshape(-1,3) + chunks = torch.split(pts, args.batch_size) + dists = [] + for chunk_pts in tqdm(chunks): + preds = model(chunk_pts) + dists.append(preds['sdf'].detach()) + dists = torch.cat(dists, dim=0) + grid = dists.reshape(args.mc_resolution, args.mc_resolution, args.mc_resolution) + + # Meshify + vertices, triangles = mcubes.marching_cubes(grid.cpu().numpy(), 0.0) + + # Resize + recenter + b_min_np = np.array([-1., -1., -1.]) + b_max_np = np.array([ 1., 1., 1.]) + vertices = vertices / (args.mc_resolution - 1.0) * (b_max_np - b_min_np) + b_min_np + + vertices, triangles = clean_mesh(vertices, triangles, min_f=8, min_d=5, repair=True, remesh=False) + + if args.decimate > 0 and triangles.shape[0] > args.decimate: + vertices, triangles = decimate_mesh(vertices, triangles, args.decimate, remesh=args.remesh) + + h0 = 1024 + w0 = 1024 + ssaa = 1 + fp16 = True + glctx = dr.RasterizeGLContext(output_db=False) + v_np = vertices.astype(np.float32) + f_np = triangles.astype(np.int64) + v = torch.from_numpy(vertices).float().contiguous().to(device) + f = torch.from_numpy(triangles.astype(np.int64)).int().contiguous().to(device) + print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}') + # unwrap uv in contracted space + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + chart_options.max_iterations = 0 # disable merge_chart for faster unwrap... 
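    # Sketch of the baking step that follows: xatlas unwraps the cleaned
    # marching-cubes mesh into a UV atlas, nvdiffrast rasterizes the UV
    # triangles at h0 x w0 = 1024 x 1024 to recover a 3D position per covered
    # texel, and the primitive model is queried over those positions in chunks
    # of args.batch_size to bake six channels per texel: [R, G, B, unused,
    # roughness, metallic]. A nearest-neighbour fill (scikit-learn KDTree over a
    # 32-pixel dilation of the coverage mask) then pads a margin around each UV
    # chart so texture sampling near seams does not pick up background. Setting
    # max_iterations = 0 above skips chart merging, trading a more fragmented
    # atlas for a faster unwrap.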
+ pack_options = xatlas.PackOptions() + # pack_options.blockAlign = True + # pack_options.bruteForce = False + atlas.generate(chart_options=chart_options, pack_options=pack_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device) + ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device) + uv = vt * 2.0 - 1.0 # uvs to range [-1, 1] + uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] + + if ssaa > 1: + h = int(h0 * ssaa) + w = int(w0 * ssaa) + else: + h, w = h0, w0 + + rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4] + xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] + mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] + # masked query + xyzs = xyzs.view(-1, 3) + mask = (mask > 0).view(-1) + feats = torch.zeros(h * w, 6, device=device, dtype=torch.float32) + + if mask.any(): + xyzs = xyzs[mask] # [M, 3] + # batched inference to avoid OOM + all_feats = [] + head = 0 + chunk_size = args.batch_size + while head < xyzs.shape[0]: + tail = min(head + chunk_size, xyzs.shape[0]) + with torch.cuda.amp.autocast(enabled=fp16): + preds = model(xyzs[head:tail]) + # [R, G, B, NA, roughness, metallic] + all_feats.append(torch.concat([preds['tex'].float(), torch.zeros_like(preds['tex'])[..., 0:1].float(), preds['mat'].float()], dim=-1)) + head += chunk_size + feats[mask] = torch.cat(all_feats, dim=0) + feats = feats.view(h, w, -1) # 6 channels + mask = mask.view(h, w) + # quantize [0.0, 1.0] to [0, 255] + feats = feats.cpu().numpy() + feats = (feats * 255) + + ### NN search as a queer antialiasing ... + mask = mask.cpu().numpy() + inpaint_region = binary_dilation(mask, iterations=32) # pad width + inpaint_region[mask] = 0 + search_region = mask.copy() + not_search_region = binary_erosion(search_region, iterations=3) + search_region[not_search_region] = 0 + search_coords = np.stack(np.nonzero(search_region), axis=-1) + inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) + knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords) + _, indices = knn.kneighbors(inpaint_coords) + feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)] + # do ssaa after the NN search, in numpy + feats0 = cv2.cvtColor(feats[..., :3].astype(np.uint8), cv2.COLOR_RGB2BGR) # albedo + feats1 = cv2.cvtColor(feats[..., 3:].astype(np.uint8), cv2.COLOR_RGB2BGR) # visibility features + if ssaa > 1: + feats0 = cv2.resize(feats0, (w0, h0), interpolation=cv2.INTER_LINEAR) + feats1 = cv2.resize(feats1, (w0, h0), interpolation=cv2.INTER_LINEAR) + + cv2.imwrite(os.path.join(ins_dir, f'texture.jpg'), feats0) + cv2.imwrite(os.path.join(ins_dir, f'roughness_metallic.jpg'), feats1) + + target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255) + target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb')) + +def main(config): + logging.basicConfig(level=logging.INFO) + ddim_steps = config.inference.ddim + if ddim_steps > 0: + use_ddim = True + else: + use_ddim = False + cfg_scale = config.inference.get("cfg", 0.0) + + inference_dir = f"{config.output_dir}/inference_folder" + os.makedirs(inference_dir, exist_ok=True) + + amp = False + precision = 
config.inference.get("precision", 'fp16') + if precision == 'tf32': + precision_dtype = torch.float32 + elif precision == 'fp16': + amp = True + precision_dtype = torch.float16 + else: + raise NotImplementedError("{} precision is not supported".format(precision)) + + device = torch.device(f"cuda:{0}") + seed = config.inference.seed + torch.manual_seed(seed) + torch.cuda.set_device(device) + + model = load_from_config(config.model.generator) + vae = load_from_config(config.model.vae) + conditioner = load_from_config(config.model.conditioner) + vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu') + vae.load_state_dict(vae_state_dict['model_state_dict']) + + if config.checkpoint_path: + state_dict = torch.load(config.checkpoint_path, map_location='cpu') + model.load_state_dict(state_dict['ema']) + vae = vae.to(device) + conditioner = conditioner.to(device) + model = model.to(device) + config.diffusion.pop("timestep_respacing") + if use_ddim: + respacing = "ddim{}".format(ddim_steps) + else: + respacing = "" + diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion) # default: 1000 steps, linear noise schedule + if use_ddim: + sample_fn = diffusion.ddim_sample_loop_progressive + else: + sample_fn = diffusion.p_sample_loop_progressive + + if cfg_scale > 0: + fwd_fn = model.forward_with_cfg + else: + fwd_fn = model.forward + + rm = RayMarcher( + config.image_height, + config.image_width, + **config.rm, + ).to(device) + + perchannel_norm = False + if "latent_mean" in config.model: + latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device) + latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device) + assert latent_mean.shape[-1] == config.model.generator.in_channels + perchannel_norm = True + + model.eval() + examples_dir = config.inference.input_dir + img_list = os.listdir(examples_dir) + rembg_session = rembg.new_session() + logger.info(f"Starting Inference...") + for img_path in img_list: + full_img_path = os.path.join(examples_dir, img_path) + img_name = img_path[:-4] + current_output_dir = os.path.join(inference_dir, img_name) + os.makedirs(current_output_dir, exist_ok=True) + input_image = Image.open(full_img_path) + input_image = remove_background(input_image, rembg_session) + input_image = resize_foreground(input_image, 0.85) + raw_image = np.array(input_image) + mask = (raw_image[..., -1][..., None] > 0) * 1 + raw_image = raw_image[..., :3] * mask + input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device) + with torch.no_grad(): + latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4) + batch = {} + inf_bs = 1 + inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device) + y = conditioner.encoder(input_cond) + model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp) + if cfg_scale > 0: + model_kwargs['cfg_scale'] = cfg_scale + sampled_count = -1 + for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ): + sampled_count += 1 + if not (sampled_count % 10 == 0 or sampled_count == diffusion.num_timesteps - 1): + continue + else: + recon_param = samples["sample"].reshape(inf_bs, config.model.num_prims, -1) + if perchannel_norm: + recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean + recon_srt_param = recon_param[:, :, 0:4] + recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64] + recon_feat_param_list = [] + # one-by-one to avoid oom + for 
inf_bidx in range(inf_bs): + if not perchannel_norm: + decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf) + else: + decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:])) + recon_feat_param_list.append(decoded.detach()) + recon_feat_param = torch.concat(recon_feat_param_list, dim=0) + # invert normalization + if not perchannel_norm: + recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05 + recon_feat_param[:, 0:1, ...] /= 5. + recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2. + recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1) + recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1) + visualize_primvolume("{}/dstep{:04d}_recon.jpg".format(current_output_dir, sampled_count), batch, recon_param, rm, device) + visualize_video_primvolume(current_output_dir, batch, recon_param, 60, rm, device) + prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()} + torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(current_output_dir)) + + if config.inference.export_glb: + logger.info(f"Starting GLB Mesh Extraction...") + config.model.pop("vae") + config.model.pop("vae_checkpoint_path") + config.model.pop("conditioner") + config.model.pop("generator") + config.model.pop("latent_nf") + config.model.pop("latent_mean") + config.model.pop("latent_std") + model_primx = load_from_config(config.model) + for img_path in img_list: + img_name = img_path[:-4] + output_path = os.path.join(inference_dir, img_name) + denoise_param_path = os.path.join(inference_dir, img_name, 'denoised.pt') + ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict'] + model_primx.load_state_dict(ckpt_weight) + model_primx.to(device) + model_primx.eval() + with torch.no_grad(): + model_primx.srt_param[:, 1:4] *= 0.85 + extract_texmesh(config.inference, model_primx, output_path, device) + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + # manually enable tf32 to get speedup on A100 GPUs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + os.environ["CC"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/gcc" + os.environ["CPP"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/g++" + os.environ["CXX"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/g++" + # set config + config = OmegaConf.load(str(sys.argv[1])) + config_cli = OmegaConf.from_cli(args_list=sys.argv[2:]) + if config_cli: + logger.info("overriding with following values from args:") + logger.info(OmegaConf.to_yaml(config_cli)) + config = OmegaConf.merge(config, config_cli) + + main(config) diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb768e2d2343e3cc6ff838ed2dab8301efe49acf --- /dev/null +++ b/install.sh @@ -0,0 +1,6 @@ +CURRENT=$(pwd) +cd dva/mvp/extensions/mvpraymarch +make -j4 +cd ../utils +make -j4 +cd CURRENT \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/attention.py b/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ed91ab1e53a1e09046563e8f8c0fd58c98f151ea --- /dev/null +++ b/models/attention.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import os +import warnings + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from xformers.ops import memory_efficient_attention, unbind + + +class MemEffAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + gradient_checkpointing: bool = False, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.gradient_checkpointing = gradient_checkpointing + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, attn_bias, use_reentrant=False) + else: + return self._forward(x, attn_bias) + + def _forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttention(nn.Module): + def __init__( + self, + dim: int, + dim_q: int, + dim_k: int, + dim_v: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + gradient_checkpointing: bool = False, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.gradient_checkpointing = gradient_checkpointing + + self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) + self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) + self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_bias=None) -> torch.Tensor: + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, q, k, v, attn_bias, use_reentrant=False) + else: + return self._forward(q, k, v, attn_bias) + + def _forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_bias=None) -> torch.Tensor: + # q: [B, N, Cq] + # k: [B, M, Ck] + # v: [B, M, Cv] + # return: [B, N, C] + + B, N, _ = q.shape + M = k.shape[1] + + q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] + k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] + v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/models/conditioner/dinov2/__init__.py b/models/conditioner/dinov2/__init__.py new 
file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/conditioner/dinov2/hub/__init__.py b/models/conditioner/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/models/conditioner/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/models/conditioner/dinov2/hub/backbones.py b/models/conditioner/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..4672ad979ff39e030fc9bad4c5a4c075412c3947 --- /dev/null +++ b/models/conditioner/dinov2/hub/backbones.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + + state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern + if vit_kwargs.get("modulation_dim") is not None: + state_dict = { + k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict, strict=False) + else: + model.load_state_dict(state_dict, strict=True) + # ******************************************************** + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/models/conditioner/dinov2/hub/classifiers.py b/models/conditioner/dinov2/hub/classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0841efa80ab3d564cd320d61da254af182606b --- /dev/null +++ b/models/conditioner/dinov2/hub/classifiers.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
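# Usage sketch (illustrative only, assuming the package is importable as
# models.conditioner.dinov2.hub): the factory functions in backbones.py above
# and in this module follow the torch.hub entry-point convention, e.g.
#
#   from models.conditioner.dinov2.hub.backbones import dinov2_vitb14_reg
#   from models.conditioner.dinov2.hub.classifiers import dinov2_vitb14_lc
#
#   backbone = dinov2_vitb14_reg(pretrained=True)    # ViT-B/14 + 4 registers, LVD-142M weights
#   feats = backbone.forward_features(torch.randn(1, 3, 518, 518))
#   patch_tokens = feats["x_norm_patchtokens"]        # [1, 37 * 37, 768]
#
#   classifier = dinov2_vitb14_lc(layers=4, pretrained=True)
#   logits = classifier(torch.randn(1, 3, 518, 518))  # [1, 1000] ImageNet-1k logits
#
# The image conditioner under models/conditioner/ appears to build only on the
# backbone path; the classifier and depth heads are vendored for completeness.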
+ +from enum import Enum +from typing import Union + +import torch +import torch.nn as nn + +from .backbones import _make_dinov2_model +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov2_linear_classification_head( + *, + arch_name: str = "vit_large", + patch_size: int = 14, + embed_dim: int = 1024, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) + + if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + layers_str = str(layers) if layers == 4 else "" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + linear_head.load_state_dict(state_dict, strict=True) + + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + self.layers = layers + + def forward(self, x): + if self.layers == 1: + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off + linear_input = torch.cat([ + cls_token, + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on + elif self.layers == 4: + x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off + linear_input = torch.cat([ + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on + else: + assert False, f"Unsupported number of layers: {self.layers}" + return self.linear_head(linear_input) + + +def _make_dinov2_linear_classifier( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + **kwargs, +): + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + linear_head = _make_dinov2_linear_classification_head( + arch_name=arch_name, + patch_size=patch_size, + embed_dim=embed_dim, + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=num_register_tokens, + ) + + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) + + +def dinov2_vits14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitb14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitl14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitg14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/models/conditioner/dinov2/hub/depth/__init__.py b/models/conditioner/dinov2/hub/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91716e58ab6158d814df8c653644d9af4c7be65c --- /dev/null +++ b/models/conditioner/dinov2/hub/depth/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .decode_heads import BNHead, DPTHead +from .encoder_decoder import DepthEncoderDecoder diff --git a/models/conditioner/dinov2/hub/depth/decode_heads.py b/models/conditioner/dinov2/hub/depth/decode_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..f455accad38fec6ecdd53460233a564c34f434da --- /dev/null +++ b/models/conditioner/dinov2/hub/depth/decode_heads.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from functools import partial +import math +import warnings + +import torch +import torch.nn as nn + +from .ops import resize + + +# XXX: (Untested) replacement for mmcv.imdenormalize() +def _imdenormalize(img, mean, std, to_bgr=True): + import numpy as np + + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = (img * std) + mean + if to_bgr: + img = img[::-1] + return img + + +class DepthBaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_layer (nn.Module): Conv layers. Default: None. + act_layer (nn.Module): Activation layers. Default: nn.ReLU. + loss_decode (dict): Config of decode loss. + Default: (). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_layer (dict|None): Norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. 
+ """ + + def __init__( + self, + in_channels, + conv_layer=None, + act_layer=nn.ReLU, + channels=96, + loss_decode=(), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_layer=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conf_layer = conv_layer + self.act_layer = act_layer + self.loss_decode = loss_decode + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_layer = norm_layer + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + + Returns: + Tensor: Output depth map. 
+ """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + import numpy as np + + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = _imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} + + +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. 
+ Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + return output + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_layer. Bias will be set as True if `norm_layer` is None, otherwise + False. Default: "auto". + conv_layer (nn.Module): Convolution layer. Default: None, + which means using conv2d. + norm_layer (nn.Module): Normalization layer. Default: None. + act_layer (nn.Module): Activation layer. Default: nn.ReLU. + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. 
+ padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_layer=nn.Conv2d, + norm_layer=None, + act_layer=nn.ReLU, + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super(ConvModule, self).__init__() + official_padding_mode = ["zeros", "circular"] + self.conv_layer = conv_layer + self.norm_layer = norm_layer + self.act_layer = act_layer + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(["conv", "norm", "act"]) + + self.with_norm = norm_layer is not None + self.with_activation = act_layer is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + if padding_mode == "zeros": + padding_layer = nn.ZeroPad2d + else: + raise AssertionError(f"Unsupported padding mode: {padding_mode}") + self.pad = padding_layer(padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = self.conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + norm = partial(norm_layer, num_features=norm_channels) + self.add_module("norm", norm) + if self.with_bias: + from torch.nnModules.batchnorm import _BatchNorm + from torch.nnModules.instancenorm import _InstanceNorm + + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn("Unnecessary conv bias before batch/instance norm") + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + # nn.Tanh has no 'inplace' argument + # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) + if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)): + act_layer = partial(act_layer, inplace=inplace) + self.activate = act_layer() + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + 
return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, "init_weights"): + if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU): + nonlinearity = "leaky_relu" + a = 0.01 # XXX: default negative_slope + else: + nonlinearity = "relu" + a = 0 + if hasattr(self.conv, "weight") and self.conv.weight is not None: + nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) + if hasattr(self.conv, "bias") and self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + if self.with_norm: + if hasattr(self.norm, "weight") and self.norm.weight is not None: + nn.init.constant_(self.norm.weight, 1) + if hasattr(self.norm, "bias") and self.norm.bias is not None: + nn.init.constant_(self.norm.bias, 0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == "conv": + if self.with_explicit_padding: + x = self.pad(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. 
+ """ + + def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_layer=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_layer (nn.Module): activation layer. + norm_layer (nn.Module): norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + """ + + def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): + super(PreActResidualConvUnit, self).__init__() + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_layer (nn.Module): activation layer for ResidualConvUnit. + norm_layer (nn.Module): normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. 
+ """ + + def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): + super(FeatureFusionBlock, self).__init__() + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/models/conditioner/dinov2/hub/depth/encoder_decoder.py 
b/models/conditioner/dinov2/hub/depth/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8 --- /dev/null +++ b/models/conditioner/dinov2/hub/depth/encoder_decoder.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ops import resize + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +class DepthEncoderDecoder(nn.Module): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone and decode_head. + """ + + def __init__(self, backbone, decode_head): + super(DepthEncoderDecoder, self).__init__() + + self.backbone = backbone + self.decode_head = decode_head + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + return self.backbone(img) + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + return self.encode_decode(img, img_meta, rescale, size=size) + + def slide_inference(self, img, img_meta, rescale, stride, crop_size): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None, mode="whole"): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+ """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. 
+ """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + import torch.distributed as dist + + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/models/conditioner/dinov2/hub/depth/ops.py b/models/conditioner/dinov2/hub/depth/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e --- /dev/null +++ b/models/conditioner/dinov2/hub/depth/ops.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/models/conditioner/dinov2/hub/depthers.py b/models/conditioner/dinov2/hub/depthers.py new file mode 100644 index 0000000000000000000000000000000000000000..f88b7e9a41056594e3b3e66107feee98bffab820 --- /dev/null +++ b/models/conditioner/dinov2/hub/depthers.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from functools import partial +from typing import Optional, Tuple, Union + +import torch + +from .backbones import _make_dinov2_model +from .depth import BNHead, DepthEncoderDecoder, DPTHead +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding + + +class Weights(Enum): + NYU = "NYU" + KITTI = "KITTI" + + +def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: + if not pretrained: # Default + return (0.001, 10.0) + + # Pretrained, set according to the training dataset for the provided weights + if weights == Weights.KITTI: + return (0.001, 80.0) + + if weights == Weights.NYU: + return (0.001, 10.0) + + return (0.001, 10.0) + + +def _make_dinov2_linear_depth_head( + *, + embed_dim: int, + layers: int, + min_depth: float, + max_depth: float, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + + if layers == 1: + in_index = [0] + else: + assert layers == 4 + in_index = [0, 1, 2, 3] + + return BNHead( + classify=True, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + upsample=4, + in_channels=[embed_dim] * len(in_index), + in_index=in_index, + input_transform="resize_concat", + channels=embed_dim * len(in_index) * 2, + align_corners=False, + min_depth=0.001, + max_depth=80, + loss_decode=(), + ) + + +def _make_dinov2_linear_depther( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + model_name = _make_dinov2_model_name(arch_name, patch_size) + linear_depth_head = _make_dinov2_linear_depth_head( + embed_dim=embed_dim, + layers=layers, + min_depth=min_depth, + max_depth=max_depth, + ) + + layer_count = { + "vit_small": 12, + "vit_base": 12, + "vit_large": 24, + "vit_giant2": 40, + }[arch_name] + + if layers == 4: + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + else: + assert layers == 1 + out_index = [layer_count - 1] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) + + if pretrained: + layers_str = str(layers) if layers == 4 else "" + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, 
strict=False) + + return model + + +def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) + + +def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): + return DPTHead( + in_channels=[embed_dim] * 4, + channels=256, + embed_dims=embed_dim, + post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], + readout_type="project", + min_depth=min_depth, + max_depth=max_depth, + loss_decode=(), + ) + + +def _make_dinov2_dpt_depther( + *, + arch_name: str = "vit_large", + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) + dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) + + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) + + if pretrained: + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, 
**kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) diff --git a/models/conditioner/dinov2/hub/utils.py b/models/conditioner/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/models/conditioner/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/models/conditioner/dinov2/layers/__init__.py b/models/conditioner/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5f723b6997ab85ca330008090c9d3bc6853bee --- /dev/null +++ b/models/conditioner/dinov2/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused + +from .block import Block, BlockWithModulation +from .attention import MemEffAttention diff --git a/models/conditioner/dinov2/layers/attention.py b/models/conditioner/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777 --- /dev/null +++ b/models/conditioner/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
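The CenterPadding hook registered on the depther backbones above pads height and width up to the next multiple of the patch size, splitting the padding evenly between both sides. A minimal sketch of the same arithmetic, assuming only standard PyTorch (the helper name pad_to_multiple is illustrative and not defined in these files):

import math
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # F.pad takes pads for the last dim first: (w_left, w_right, h_top, h_bottom)
    pads = []
    for size in (x.shape[-1], x.shape[-2]):
        new_size = math.ceil(size / multiple) * multiple
        total = new_size - size
        pads += [total // 2, total - total // 2]
    return F.pad(x, pads)

x = torch.randn(1, 3, 518, 515)        # width 515 is not a multiple of 14
print(pad_to_multiple(x, 14).shape)    # torch.Size([1, 3, 518, 518])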
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/models/conditioner/dinov2/layers/block.py b/models/conditioner/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..7bca0f0d4a46e668b6ed32c8a10df556b405996a --- /dev/null +++ b/models/conditioner/dinov2/layers/block.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
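The Attention module above materializes the full (B, heads, N, N) score matrix, while MemEffAttention swaps in xFormers' memory_efficient_attention when it is installed. As a rough sanity check of the head reshape and scaling, the manual path can be compared against PyTorch's fused kernel; this sketch assumes PyTorch 2.x, whose F.scaled_dot_product_attention defaults to the same 1/sqrt(head_dim) scale:

import torch
import torch.nn.functional as F

B, N, C, H = 2, 5, 16, 4                           # batch, tokens, channels, heads
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                   # each (B, H, N, C // H)

scale = (C // H) ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)           # (B, H, N, N) logits
out_manual = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)

out_fused = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)
print(torch.allclose(out_manual, out_fused, atol=1e-5))   # True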
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +# Override forward with modulation input +class BlockWithModulation(Block): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, mod: Tensor) -> Tensor: + def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x, mod))) + + def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor: + 
return self.ls2(self.mlp(self.norm2(x, mod))) + + if self.training and self.sample_drop_ratio > 0.1: + raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet") + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, mod)) + x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, mod) + x = x + ffn_residual_func(x, mod) + return x +# ******************************************************** + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = 
get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs diff --git a/models/conditioner/dinov2/layers/dino_head.py b/models/conditioner/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/models/conditioner/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/models/conditioner/dinov2/layers/drop_path.py b/models/conditioner/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/models/conditioner/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
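The drop_path helper defined just below zeroes out entire samples on a residual branch during training and rescales the survivors by 1/keep_prob, so the branch keeps its expected magnitude. A small numerical check of that property, assuming nothing beyond torch:

import torch

torch.manual_seed(0)
keep_prob = 0.8                                    # i.e. drop_prob = 0.2
x = torch.ones(100_000, 1, 1)                      # stand-in for a residual branch output

mask = x.new_empty(x.shape[0], 1, 1).bernoulli_(keep_prob)
out = x * mask / keep_prob                         # survivors scaled up by 1 / keep_prob

print(out.mean().item())                           # ~1.0: expectation matches E[x]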
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/models/conditioner/dinov2/layers/layer_scale.py b/models/conditioner/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/models/conditioner/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/models/conditioner/dinov2/layers/mlp.py b/models/conditioner/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/models/conditioner/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
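LayerScale, defined above, multiplies a residual branch by a learnable per-channel gamma initialized to init_values (1e-5 by default), so a freshly initialized block stays close to the identity mapping. A minimal illustration of that damping:

import torch
from torch import nn

dim = 8
gamma = nn.Parameter(1e-5 * torch.ones(dim))   # same init as LayerScale(dim, init_values=1e-5)
branch_out = torch.randn(2, 4, dim)            # pretend attention/MLP branch output

scaled = branch_out * gamma                    # broadcast over the channel dim
print(scaled.abs().max().item() < 1e-3)        # True: the branch barely perturbs the residual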
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/models/conditioner/dinov2/layers/patch_embed.py b/models/conditioner/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/models/conditioner/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/models/conditioner/dinov2/layers/swiglu_ffn.py b/models/conditioner/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74 --- /dev/null +++ b/models/conditioner/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/models/conditioner/dinov2/models/__init__.py b/models/conditioner/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/models/conditioner/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . 
import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/models/conditioner/dinov2/models/vision_transformer.py b/models/conditioner/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1a40c272758e705061aaca41de35d78f29f557 --- /dev/null +++ b/models/conditioner/dinov2/models/vision_transformer.py @@ -0,0 +1,418 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation +# ******************************************************** + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + modulation_dim: int = None, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding 
dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + + block_norm_layer = None + if modulation_dim is not None: + from ....modulate import ModLN + block_norm_layer = partial(ModLN, mod_dim=modulation_dim) + else: + block_norm_layer = nn.LayerNorm + block_norm_layer = partial(block_norm_layer, eps=1e-6) + # ******************************************************** + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=block_norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + 
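+            # e.g., depth=12 with block_chunks=4 yields 4 BlockChunk groups of 3 blocks each,
+            # left-padded with nn.Identity so every block keeps its global index inside its chunk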
else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + raise NotImplementedError("Masking is not supported in hacked DINOv2") + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, mod=None): + if isinstance(x, list): + raise DeprecationWarning("forward_features_list is deprecated, use forward_features") + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + if mod is None: + for blk in self.blocks: + x = blk(x) + else: + for blk in self.blocks: + x = blk(x, mod) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + # ******************************************************** + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. 
If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def _block_cls(**kwargs): + modulation_dim = kwargs.get("modulation_dim", None) + if modulation_dim is None: + block_cls = Block + else: + block_cls = BlockWithModulation + return block_cls + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + 
""" + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + +# ******************************************************** diff --git a/models/conditioner/image.py b/models/conditioner/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7b43144a63d23c6748738d3ce933a2a335e82d --- /dev/null +++ b/models/conditioner/image.py @@ -0,0 +1,292 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode + +import open_clip +from dva.io import load_from_config + +def sample_orbit_traj(radius, height, start_theta, end_theta, num_points, world_up=torch.Tensor([0, 1, 0])): + # return [num_points, 3, 4] + angles = torch.rand((num_points, )) * (end_theta - start_theta) + start_theta + return get_pose_on_orbit(radius=radius, height=height, angles=angles, world_up=world_up) + +def get_pose_on_orbit(radius, height, angles, world_up=torch.Tensor([0, 1, 0])): + num_points = angles.shape[0] + x = radius * torch.cos(angles) + h = torch.ones((num_points,)) * height + z = radius * torch.sin(angles) + position = torch.stack([x, h, z], dim=-1) + forward = position / torch.norm(position, p=2, dim=-1, keepdim=True) + right = -torch.cross(world_up[None, ...], forward) + right /= torch.norm(right, dim=-1, keepdim=True) + up = torch.cross(forward, right) + up /= torch.norm(up, p=2, dim=-1, keepdim=True) + rotation = torch.stack([right, up, forward], dim=1) + translation = torch.Tensor([0, 0, radius])[None, :, None].repeat(num_points, 1, 1) + return torch.concat([rotation, translation], dim=2) + +class DummyImageConditioner(nn.Module): + def __init__( + self, + num_prims, + dim_feat, + prim_shape, + encoder_config, + sample_view=False, + sample_start=torch.pi*0.25, + sample_end=torch.pi*0.75, + ): + super().__init__() + + self.num_prims = num_prims + self.dim_feat = dim_feat + self.prim_shape = prim_shape + self.sample_view = sample_view + self.sample_start = sample_start + self.sample_end = sample_end + self.encoder = None + + @torch.no_grad() + def forward(self, batch, rm, amp, precision_dtype=torch.float32): + return batch['cond'] + +class ImageConditioner(nn.Module): + def __init__( + self, + num_prims, + dim_feat, + prim_shape, + encoder_config, + sample_view=False, + sample_start=torch.pi*0.25, + sample_end=torch.pi*0.75, + ): + super().__init__() + + self.num_prims = num_prims + self.dim_feat = dim_feat + self.prim_shape = prim_shape + self.sample_view = sample_view + self.sample_start = sample_start + self.sample_end = sample_end + self.encoder = load_from_config(encoder_config) + + def sdf2alpha(self, sdf): + return torch.exp(-(sdf / 0.005) ** 2) + + @torch.no_grad() + def forward(self, batch, rm, amp, precision_dtype=torch.float32): + # TODO: replace with real rendering process in primsdf + assert 'input_param' in batch, "No parameters in current batch for rendering image conditions" + prim_volume = batch['input_param'] + bs = prim_volume.shape[0] + preds = {} + geo_start_index = 4 + geo_end_index = geo_start_index + self.prim_shape ** 3 # non-inclusive + tex_start_index = geo_end_index + tex_end_index = tex_start_index + self.prim_shape ** 3 * 3 # non-inclusive + feat_geo = prim_volume[:, 
:, geo_start_index: geo_end_index] + feat_tex = prim_volume[:, :, tex_start_index: tex_end_index] + prim_alpha = self.sdf2alpha(feat_geo).reshape(bs, self.num_prims, 1, self.prim_shape, self.prim_shape, self.prim_shape) * 255 + prim_rgb = feat_tex.reshape(bs, self.num_prims, 3, self.prim_shape, self.prim_shape, self.prim_shape) * 255 + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + pos = prim_volume[:, :, 1:4] + scale = prim_volume[:, :, 0:1] + preds['prim_pos'] = pos.reshape(bs, self.num_prims, 3) * rm.volradius + preds['prim_rot'] = torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, self.num_prims, 1, 1) + preds['prim_scale'] = (1 / scale.reshape(bs, self.num_prims, 1).repeat(1, 1, 3)) + if not self.sample_view: + preds['Rt'] = torch.Tensor([ + [ + 1.0, + 0.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + -1.0, + 0.0, + 0.0 * rm.volradius + ], + [ + 0.0, + 0.0, + -1.0, + 5 * rm.volradius + ] + ]).to(prim_volume)[None, ...].repeat(bs, 1, 1) + else: + preds['Rt'] = sample_orbit_traj(radius=5*rm.volradius, height=0, start_theta=self.sample_start, end_theta=self.sample_end, num_points=bs).to(prim_volume) + preds['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(prim_volume)[None, ...].repeat(bs, 1, 1) + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. + preds['K'][:, 0:1, :] *= ratio_h + preds['K'][:, 1:2, :] *= ratio_w + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=preds["Rt"], + K=preds["K"], + ) + rendered_image = rm_preds['rgba_image'].permute(0, 2, 3, 1)[..., :3].contiguous() + with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=amp): + results = self.encoder(rendered_image) + return results + +class ImageMultiViewConditioner(nn.Module): + def __init__( + self, + num_prims, + dim_feat, + prim_shape, + encoder_config, + sample_view=False, + view_counts=4, + ): + super().__init__() + + self.num_prims = num_prims + self.dim_feat = dim_feat + self.prim_shape = prim_shape + self.view_counts = view_counts + view_angles = torch.linspace(0.5, 2.5, self.view_counts + 1) * torch.pi + self.view_angles = view_angles[:-1] + self.encoder = load_from_config(encoder_config) + + def sdf2alpha(self, sdf): + return torch.exp(-(sdf / 0.005) ** 2) + + @torch.no_grad() + def forward(self, batch, rm, amp, precision_dtype=torch.float32): + # TODO: replace with real rendering process in primsdf + assert 'input_param' in batch, "No parameters in current batch for rendering image conditions" + prim_volume = batch['input_param'] + bs = prim_volume.shape[0] + preds = {} + geo_start_index = 4 + geo_end_index = geo_start_index + self.prim_shape ** 3 # non-inclusive + tex_start_index = geo_end_index + tex_end_index = tex_start_index + self.prim_shape ** 3 * 3 # non-inclusive + feat_geo = prim_volume[:, :, geo_start_index: geo_end_index] + feat_tex = prim_volume[:, :, tex_start_index: tex_end_index] + prim_alpha = self.sdf2alpha(feat_geo).reshape(bs, self.num_prims, 1, self.prim_shape, self.prim_shape, self.prim_shape) * 255 + prim_rgb = feat_tex.reshape(bs, self.num_prims, 3, self.prim_shape, self.prim_shape, self.prim_shape) * 255 + preds['prim_rgba'] = torch.concat([prim_rgb, prim_alpha], dim=2) + pos = prim_volume[:, :, 1:4] + scale = prim_volume[:, :, 0:1] + preds['prim_pos'] = pos.reshape(bs, self.num_prims, 3) * rm.volradius + preds['prim_rot'] = 
torch.eye(3).to(preds['prim_pos'])[None, None, ...].repeat(bs, self.num_prims, 1, 1) + preds['prim_scale'] = (1 / scale.reshape(bs, self.num_prims, 1).repeat(1, 1, 3)) + preds['K'] = torch.Tensor([ + [ + 2084.9526697685183, + 0.0, + 512.0 + ], + [ + 0.0, + 2084.9526697685183, + 512.0 + ], + [ + 0.0, + 0.0, + 1.0 + ]]).to(prim_volume)[None, ...].repeat(bs, 1, 1) + ratio_h = rm.image_height / 1024. + ratio_w = rm.image_width / 1024. + preds['K'][:, 0:1, :] *= ratio_h + preds['K'][:, 1:2, :] *= ratio_w + # we sample view according to view_counts + cond_list = [] + for view_ang in self.view_angles: + bs_view_ang = view_ang.repeat(bs,) + preds['Rt'] = get_pose_on_orbit(radius=5*rm.volradius, height=0, angles=bs_view_ang).to(prim_volume) + rm_preds = rm( + prim_rgba=preds["prim_rgba"], + prim_pos=preds["prim_pos"], + prim_scale=preds["prim_scale"], + prim_rot=preds["prim_rot"], + RT=preds["Rt"], + K=preds["K"], + ) + rendered_image = rm_preds['rgba_image'].permute(0, 2, 3, 1)[..., :3].contiguous() + with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=amp): + results = self.encoder(rendered_image) + cond_list.append(results) + final_cond = torch.concat(cond_list, dim=1) + return final_cond + +class CLIPImageEncoder(nn.Module): + def __init__( + self, + pretrained_path: str, + model_spec: str = 'ViT-L-14', + ): + super().__init__() + + self.model, _, _ = open_clip.create_model_and_transforms(model_spec, pretrained=pretrained_path) + self.model_resolution = self.model.visual.image_size + self.preprocess = Compose([ + Resize(self.model_resolution, interpolation=InterpolationMode.BICUBIC), + CenterCrop(self.model_resolution), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + self.model.eval() + # self.tokenizer = open_clip.get_tokenizer(model_spec) + + @torch.no_grad() + def forward(self, img): + assert img.shape[-1] == 3 + img = img.permute(0, 3, 1, 2) / 255. + image = self.preprocess(img) + image_features = self.model.encode_image(image) + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features + +class CLIPImageTokenEncoder(nn.Module): + def __init__( + self, + pretrained_path: str, + model_spec: str = 'ViT-L-14', + ): + super().__init__() + + self.model, _, _ = open_clip.create_model_and_transforms(model_spec, pretrained=pretrained_path) + self.model.visual.output_tokens = True + self.model_resolution = self.model.visual.image_size + self.preprocess = Compose([ + Resize(self.model_resolution, interpolation=InterpolationMode.BICUBIC), + CenterCrop(self.model_resolution), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + self.model.eval() + + @torch.no_grad() + def forward(self, img): + assert img.shape[-1] == 3 + img = img.permute(0, 3, 1, 2) / 255. 
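+        # img comes in as [B, H, W, 3] in [0, 255]; convert to [B, 3, H, W] in [0, 1]
+        # before the CLIP resize / center-crop / normalization pipeline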
+ image = self.preprocess(img) + _, image_tokens = self.model.encode_image(image) + # [B, T, D] - [B, 256, 1024] + image_tokens /= image_tokens.norm(dim=-1, keepdim=True) + return image_tokens \ No newline at end of file diff --git a/models/conditioner/image_dinov2.py b/models/conditioner/image_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..aff157c4c0d97812a6b2f8f5252744b30dff1c49 --- /dev/null +++ b/models/conditioner/image_dinov2.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +from torchvision.transforms import Compose, Resize, InterpolationMode, Normalize + +import logging +logger = logging.getLogger(__name__) + + + + +class Dinov2Wrapper(nn.Module): + """ + Dino v2 wrapper using original implementation, hacked with modulation. + """ + def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True): + super().__init__() + self.modulation_dim = modulation_dim + self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim) + self.preprocess = Compose([ + Resize(self.model.patch_embed.img_size[0], interpolation=InterpolationMode.BICUBIC), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + if freeze: + if modulation_dim is not None: + raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.") + self._freeze() + + def _freeze(self): + logger.warning(f"======== Freezing Dinov2Wrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True): + from importlib import import_module + dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__) + model_fn = getattr(dinov2_hub, model_name) + logger.info(f"Modulation dim for Dinov2 is {modulation_dim}.") + model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained) + return model + + # @torch.compile + def forward(self, image: torch.Tensor, mod: torch.Tensor = None): + # image: [N, H, W, C] -- need to be permuted!!! + # mod: [N, D] or None + assert image.shape[-1] == 3 + image = image.permute(0, 3, 1, 2) / 255. + image = self.preprocess(image) + if self.modulation_dim is None: + assert mod is None, "Unexpected modulation input in dinov2 forward." + outs = self.model(image, is_training=True) + else: + assert mod is not None, "Modulation input is required in modulated dinov2 forward." 
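+            # modulated path: mod is a per-sample embedding ([N, D], see comment above)
+            # consumed by the ModLN norm layers inside each transformer block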
+ outs = self.model(image, mod=mod, is_training=True) + ret = torch.cat([ + outs["x_norm_clstoken"].unsqueeze(dim=1), + outs["x_norm_patchtokens"], + ], dim=1) + # ret in [B, 1370, 384] + return ret diff --git a/models/conditioner/text.py b/models/conditioner/text.py new file mode 100644 index 0000000000000000000000000000000000000000..1b889d9acb911884660cb63579f66d777090f201 --- /dev/null +++ b/models/conditioner/text.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import open_clip +from dva.io import load_from_config + +class TextConditioner(nn.Module): + def __init__( + self, + encoder_config, + ): + super().__init__() + self.encoder = load_from_config(encoder_config) + + @torch.no_grad() + def forward(self, batch, rm, amp=False, precision_dtype=torch.float32): + assert 'caption_token' in batch, "No tokenized caption in current batch for text conditions" + caption_token = batch['caption_token'] + with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=amp): + results = self.encoder(caption_token) + return results + +class CLIPTextEncoder(nn.Module): + def __init__( + self, + pretrained_path: str, + model_spec: str = 'ViT-L-14', + ): + super().__init__() + self.model, _, _ = open_clip.create_model_and_transforms(model_spec, pretrained=pretrained_path) + self.model.eval() + + @torch.no_grad() + def forward(self, text): + text_features = self.model.encode_text(text) + text_features /= text_features.norm(dim=-1, keepdim=True) + return text_features[:, None, :] diff --git a/models/diffusion/__init__.py b/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65e285e53d913f4638b1c5ebe82a1d02aff6a341 --- /dev/null +++ b/models/diffusion/__init__.py @@ -0,0 +1,52 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . 
import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + parameterization="eps", + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + if parameterization == "eps": + model_mean_type = gd.ModelMeanType.EPSILON + elif parameterization == "xstart": + model_mean_type = gd.ModelMeanType.START_X + elif parameterization == "v": + model_mean_type = gd.ModelMeanType.VELOCITY + else: + raise NotImplementedError("Model Mean Type {} is not supported!".format(parameterization)) + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=model_mean_type, + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/models/diffusion/diffusion_utils.py b/models/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/models/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
+ """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/models/diffusion/gaussian_diffusion.py b/models/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c18bd9557d99ff3af427608dd74b6efad84f5d2e --- /dev/null +++ b/models/diffusion/gaussian_diffusion.py @@ -0,0 +1,892 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + VELOCITY = enum.auto() + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
+ """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x T x D] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + B, nt, C = x.shape + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, nt ,C * 2) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + elif self.model_mean_type == ModelMeanType.EPSILON: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + elif self.model_mean_type == ModelMeanType.VELOCITY: + pred_xstart = process_xstart( + self._predict_xstart_from_z_and_v(x_t=x, t=t, v=model_output) + ) + else: + raise NotImplementedError("Model Mean type {} is not supported!".format(self.model_mean_type)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_z_and_v(self, x_t, t, v): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def get_v(self, x, noise, t): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
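+        Concretely, eps is shifted to eps - sqrt(1 - alpha_bar_t) * cond_fn(x, t), and the mean
+        is recomputed from the resulting x_0 prediction.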
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). 
+ Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). 
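+        eta=0.0 gives the deterministic DDIM sampler; larger eta injects noise, with eta=1.0
+        roughly matching the DDPM posterior noise scale.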
+ """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x T x D] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss_total"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss_total"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, nt, C = x_t.shape + assert model_output.shape == (B, nt, C * 2) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["loss_vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["loss_vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + ModelMeanType.VELOCITY: self.get_v(x_start, noise, t) + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["loss_mse"] = mean_flat((target - model_output) ** 2) + if "loss_vb" in terms: + terms["loss_total"] = terms["loss_mse"] + terms["loss_vb"] + else: + terms["loss_total"] = terms["loss_mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/models/diffusion/respace.py b/models/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2cc0435d1ace54466585db9043b284973d454e --- /dev/null +++ b/models/diffusion/respace.py @@ -0,0 +1,129 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/models/diffusion/timestep_sampler.py b/models/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/models/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. 
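# Minimal sketch of the timestep remapping done by _WrappedModel: indices in the respaced
# process are mapped back to original-process timesteps before the model sees them. Toy values only.
import torch as th

timestep_map = list(range(0, 1000, 40))               # as built by SpacedDiffusion.__init__ for "ddim25"
ts = th.tensor([0, 1, 24])                            # timesteps in the spaced process
map_tensor = th.tensor(timestep_map, dtype=ts.dtype)
new_ts = map_tensor[ts]                               # tensor([0, 40, 960]) is what the wrapped model receives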
+        :param local_losses: a 1D Tensor of losses.
+        """
+        batch_sizes = [
+            th.tensor([0], dtype=th.int32, device=local_ts.device)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(
+            batch_sizes,
+            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+        )
+
+        # Pad all_gather batches to be the maximum batch size.
+        batch_sizes = [x.item() for x in batch_sizes]
+        max_bs = max(batch_sizes)
+
+        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+        dist.all_gather(timestep_batches, local_ts)
+        dist.all_gather(loss_batches, local_losses)
+        timesteps = [
+            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+        ]
+        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+        self.update_with_all_losses(timesteps, losses)
+
+    @abstractmethod
+    def update_with_all_losses(self, ts, losses):
+        """
+        Update the reweighting using losses from a model.
+        Sub-classes should override this method to update the reweighting
+        using losses from the model.
+        This method directly updates the reweighting without synchronizing
+        between workers. It is called by update_with_local_losses from all
+        ranks with identical arguments. Thus, it should have deterministic
+        behavior to maintain state across workers.
+        :param ts: a list of int timesteps.
+        :param losses: a list of float losses, one per timestep.
+        """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+        self.diffusion = diffusion
+        self.history_per_term = history_per_term
+        self.uniform_prob = uniform_prob
+        self._loss_history = np.zeros(
+            [diffusion.num_timesteps, history_per_term], dtype=np.float64
+        )
+        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+    def weights(self):
+        if not self._warmed_up():
+            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+        weights /= np.sum(weights)
+        weights *= 1 - self.uniform_prob
+        weights += self.uniform_prob / len(weights)
+        return weights
+
+    def update_with_all_losses(self, ts, losses):
+        for t, loss in zip(ts, losses):
+            if self._loss_counts[t] == self.history_per_term:
+                # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/models/dit_crossattn.py b/models/dit_crossattn.py new file mode 100644 index 0000000000000000000000000000000000000000..6576f8dfc46be3fcadadd6d7600179d48f3253c2 --- /dev/null +++ b/models/dit_crossattn.py @@ -0,0 +1,301 @@ +# A modified version of DiT (Diffusion Transformer) to support directly dealing with 3D primitives with shape of [batch_size, sequence_length, dim_feat] + +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# DiT: https://github.dev/facebookresearch/DiT +# -------------------------------------------------------- + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +import numpy as np +import math +from itertools import repeat +import collections.abc +from .attention import MemEffCrossAttention, MemEffAttention +from .utils import TimestepEmbedder, Mlp, modulate + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + def __init__(self, hidden_size, cross_attn_cond_dim, num_heads, mlp_ratio=4.0, proj_bias=False, gradient_checkpointing=False, **block_kwargs): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.crossattn = MemEffCrossAttention(dim=hidden_size, dim_q=hidden_size, dim_k=cross_attn_cond_dim, dim_v=cross_attn_cond_dim, num_heads=num_heads, qkv_bias=True, proj_bias=proj_bias, attn_drop=0.0, proj_drop=0.0, gradient_checkpointing=gradient_checkpointing, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = MemEffAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=True, proj_bias=proj_bias, attn_drop=0.0, proj_drop=0.0, gradient_checkpointing=gradient_checkpointing, **block_kwargs) + self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 9 * hidden_size, bias=True) + ) + + def forward(self, x, cross_attn_cond, mod_cond): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, cross_attn_cond, mod_cond, use_reentrant=False) + else: + return self._forward(x, cross_attn_cond, mod_cond) + + def _forward(self, x, cross_attn_cond, mod_cond): + # cross_attn_cond: conditions that use cross attention to cond, would be image tokens typically [B, L_cond, D_cond] + # mod_cond: conditions that uses modulation to cond, would be timestep typically [B, D_mod] + shift_mca, scale_mca, gate_mca, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod_cond).chunk(9, dim=1) + x = x + gate_mca.unsqueeze(1) * self.crossattn(modulate(self.norm1(x), shift_mca, 
scale_mca), cross_attn_cond, cross_attn_cond) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm2(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, seq_length, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class PointEmbed(nn.Module): + def __init__(self, hidden_dim=48, dim=128): + super().__init__() + + assert hidden_dim % 6 == 0 + + self.embedding_dim = hidden_dim + e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi + e = torch.stack([ + torch.cat([e, torch.zeros(self.embedding_dim // 6), + torch.zeros(self.embedding_dim // 6)]), + torch.cat([torch.zeros(self.embedding_dim // 6), e, + torch.zeros(self.embedding_dim // 6)]), + torch.cat([torch.zeros(self.embedding_dim // 6), + torch.zeros(self.embedding_dim // 6), e]), + ]) + self.register_buffer('basis', e) # 3 x 16 + + self.mlp = nn.Linear(self.embedding_dim+3, dim) + + @staticmethod + def embed(input, basis): + projections = torch.einsum('bnd,de->bne', input, basis) + embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) + return embeddings + + def forward(self, input): + # input: B x N x 3 + embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C + return embed + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
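# Minimal standalone sketch of the adaLN-Zero mechanism used by DiTBlock: the conditioning
# vector produces three (shift, scale, gate) triplets, and zero-initialising the last linear
# layer makes every gated residual branch a no-op at the start of training. Toy sizes only.
import torch
import torch.nn as nn

hidden = 8
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden, 9 * hidden, bias=True))
nn.init.constant_(adaLN[-1].weight, 0)
nn.init.constant_(adaLN[-1].bias, 0)

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = torch.randn(2, 5, hidden)                          # [B, num_prims, hidden]
c = torch.randn(2, hidden)                             # timestep embedding
shift_msa, scale_msa, gate_msa = adaLN(c).chunk(9, dim=1)[3:6]
out = x + gate_msa.unsqueeze(1) * modulate(x, shift_msa, scale_msa)
assert torch.allclose(out, x)                          # gates start at zero, so the block begins as an identity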
+ """ + def __init__( + self, + seq_length=2, + in_channels=4, + condition_channels=512, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + cond_drop_prob=0.0, + attn_proj_bias=False, + learn_sigma=True, + gradient_checkpointing=False, + ): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.seq_length = seq_length + self.num_heads = num_heads + self.cond_drop_prob = cond_drop_prob + if self.cond_drop_prob > 0: + self.null_cond_embedding = nn.Parameter(torch.randn(condition_channels)) + + # no need to patchify as prim representation is already patch-wise + self.x_embedder = nn.Linear(in_channels, hidden_size) + # self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, condition_channels, num_heads, mlp_ratio=mlp_ratio, proj_bias=attn_proj_bias, gradient_checkpointing=gradient_checkpointing) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, seq_length, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # # Initialize (and freeze) pos_embed by sin-cos embedding: + # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + # self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, precision_dtype=torch.float32, enable_amp=False): + """ + Forward pass of DiT. 
+ x: (N, T, D) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + x = self.x_embedder(x) + t = self.t_embedder(t) # (N, D) + if self.cond_drop_prob > 0 and self.training: + drop_mask = torch.rand(y.shape[0], device=y.device) < self.cond_drop_prob + null_cond_embed = self.null_cond_embedding[None, None, :] + y = torch.where(drop_mask[:, None, None], null_cond_embed, y) + with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=enable_amp): + for block in self.blocks: + x = block(x=x, cross_attn_cond=y, mod_cond=t) # (N, T, D) + #TODO: final layer only has timestep conditions, no sure if could be better + x = self.final_layer(x, t) # (N, T, D) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale=0.0, precision_dtype=torch.float32, enable_amp=False): + combined = torch.cat([x, x], dim=0) + combined_t = torch.cat([t, t], dim=0) + y_null = self.null_cond_embedding.expand_as(y) + combined_y = torch.cat([y, y_null], dim=0) + model_out = self.forward(combined, combined_t, combined_y, precision_dtype, enable_amp) + eps = model_out + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + return half_eps + +class DiTAdditivePosEmb(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + seq_length=2, + in_channels=4, + condition_channels=512, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + attn_proj_bias=False, + learn_sigma=True, + gradient_checkpointing=False, + ): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.seq_length = seq_length + self.num_heads = num_heads + + # no need to patchify as prim representation is already patch-wise + self.point_emb = PointEmbed(hidden_dim=48, dim=hidden_size) + self.x_embedder = nn.Linear(in_channels, hidden_size) + # self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, condition_channels, num_heads, mlp_ratio=mlp_ratio, proj_bias=attn_proj_bias, gradient_checkpointing=gradient_checkpointing) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, seq_length, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # # Initialize (and freeze) pos_embed by sin-cos embedding: + # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + # self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 
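# Minimal sketch of the classifier-free guidance combination used in forward_with_cfg():
# the unconditional prediction is pushed along the conditional direction by cfg_scale.
import torch

cond_eps = torch.randn(1, 16, 8)
uncond_eps = torch.randn_like(cond_eps)

def cfg(cond, uncond, scale):
    return uncond + scale * (cond - uncond)

assert torch.allclose(cfg(cond_eps, uncond_eps, 0.0), uncond_eps)   # scale 0: purely unconditional
assert torch.allclose(cfg(cond_eps, uncond_eps, 1.0), cond_eps)     # scale 1: purely conditional
guided = cfg(cond_eps, uncond_eps, 6.0)                             # scale > 1 extrapolates past the conditional prediction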
0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, precision_dtype=torch.float32, enable_amp=False): + """ + Forward pass of DiT. + x: (N, T, D) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + point = x[:, :, 1:4] + point_emb = self.point_emb(point) + x = self.x_embedder(x) + point_emb + t = self.t_embedder(t) # (N, D) + with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=enable_amp): + for block in self.blocks: + x = block(x=x, cross_attn_cond=y, mod_cond=t) # (N, T, D) + #TODO: final layer only has timestep conditions, no sure if could be better + x = self.final_layer(x, t) # (N, T, D) + return x \ No newline at end of file diff --git a/models/primsdf.py b/models/primsdf.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e6e5ff93e6bd752b512e11e7b1c7c55e785bdb --- /dev/null +++ b/models/primsdf.py @@ -0,0 +1,136 @@ +import torch +import trimesh +import torch.nn as nn +import torch.nn.functional as F + +import logging + +logger = logging.getLogger(__name__) + +class PrimSDF(nn.Module): + def __init__(self, mesh_obj=None, f_sdf=None, geo_fn=None, asset_list=None, num_prims=1024, dim_feat=6, prim_shape=8, init_scale=0.05, sdf2alpha_var=0.005, auto_scale_init=True, init_sampling="uniform"): + super().__init__() + self.num_prims = num_prims + # 6 channels features - [SDF, R, G, B, roughness, metallic] + self.dim_feat = dim_feat + self.prim_shape = prim_shape + self.sdf_sampled_point = None + self.auto_scale_init = auto_scale_init + self.init_sampling = init_sampling + self.sdf2alpha_var = sdf2alpha_var + + # assume the mesh is normalized to [-1, 1] cube + self.mesh_obj = mesh_obj + self.f_sdf = f_sdf + # N x (D x S^3 + 3(Global Translation) + 1(Global Scale)) + self.srt_param = nn.parameter.Parameter(torch.zeros(self.num_prims, 1 + 3)) + self.feat_param = nn.parameter.Parameter(torch.zeros(self.num_prims, self.dim_feat * (self.prim_shape ** 3))) + self.geo_start_index = 0 + self.geo_end_index = self.geo_start_index + self.prim_shape ** 3 # non-inclusive + self.tex_start_index = self.geo_end_index + self.tex_end_index = self.tex_start_index + self.prim_shape ** 3 * 3 # non-inclusive + self.mat_start_index = self.tex_end_index + self.mat_end_index = self.mat_start_index + self.prim_shape ** 3 * 2 + + # sampled_point -> local grid + # local_grid - [prim_shape^3, 3] + xx = torch.linspace(-1, 1, self.prim_shape) + # two ways to sample xyz-axis aligned local grids: 1st is ij indexing + meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing='ij') + local_grid = torch.stack((meshz, meshy, meshx), dim=-1).reshape(-1, 3) + self.local_grid = local_grid + # second is xy indexing, equivalent to the first one + # meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing='xy') + # local_grid = torch.stack((meshz, meshx, meshy), dim=-1).reshape(-1, 3) + if self.f_sdf is not None and geo_fn is not None and asset_list is not None: + self._init_param(init_scale=init_scale, geo_fn=geo_fn, asset_list=asset_list, sampling=self.init_sampling) + + @torch.no_grad() + def _init_param(self, init_scale, geo_fn, asset_list, sampling="uniform"): + pass + + def forward(self, x): + # x - [bs, 3] + bs = x.shape[0] + weights = self.prim_weight(x) + output = self.grid_sample_feat(x, weights) + preds 
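# Worked example of the per-primitive parameter layout in PrimSDF, using the defaults
# dim_feat=6 and prim_shape=8 (num_prims is irrelevant for the indexing).
prim_shape, dim_feat = 8, 6
voxels = prim_shape ** 3                      # 512 grid points per primitive
srt_dim = 1 + 3                               # global scale + 3D translation (srt_param)
feat_dim = dim_feat * voxels                  # 3072 values per primitive (feat_param)
geo = (0, voxels)                             # [0, 512): SDF channel
tex = (voxels, voxels + 3 * voxels)           # [512, 2048): RGB channels
mat = (4 * voxels, 4 * voxels + 2 * voxels)   # [2048, 3072): roughness / metallic
assert mat[1] == feat_dim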
= {} + preds['sdf'] = output[:, 0:1] + # RGB + preds['tex'] = torch.clip(output[:, 1:4], min=0.0, max=1.0) + # roughness, metallic + preds['mat'] = torch.clip(output[:, 4:6], min=0.0, max=1.0) + return preds + + def grid_sample_feat(self, x, weights): + # implementation of I_V -> trilinear grid sample of V_i + # x - [bs, 3] + # weights - [bs, n_prims] + bs = x.shape[0] + sampled_point = (x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...] + mask = weights > 0 + ind_bs, ind_nprim = torch.where(weights > 0) + masked_sampled_point = sampled_point[ind_bs, ind_nprim, :].reshape(ind_nprim.shape[0], 1, 1, 1, 3) + feat4sample = self.feat[ind_nprim, :].reshape(ind_nprim.shape[0], self.dim_feat, self.prim_shape, self.prim_shape, self.prim_shape) + + sampled_feat = F.grid_sample(feat4sample, masked_sampled_point, mode='bilinear', padding_mode='zeros', align_corners=True).reshape(ind_nprim.shape[0], self.dim_feat) + weighted_sampled_feat = sampled_feat * weights[mask][:, None] + weighted_feat = torch.zeros(bs, self.dim_feat).to(x) + weighted_feat.index_add_(0, ind_bs, weighted_sampled_feat) + + # at inference time, fill in approximated SDF value for region not covered by prims + if not self.training: + # get mask for points not covered by prims + bs_mask = weights.sum(1) <= 0 + + # get nearest prim index + dist = torch.norm(x[bs_mask, None, :] - self.pos[None, ...], p=2, dim=-1) + _, min_dist_ind = dist.min(1) + nearest_prim_pos = self.pos[min_dist_ind, :] + nearest_prim_scale = self.scale[min_dist_ind, :] + + # in each nearest prim, get nearest voxel points + candidate_nearest_pts = nearest_prim_pos[:, None, :] + nearest_prim_scale[..., None] * self.local_grid.to(x)[None, :] + pts_dist = torch.norm(x[bs_mask, None, :] - candidate_nearest_pts, p=2, dim=-1) + min_dist, min_dist_pts_ind = pts_dist.min(1) + + # get the SDF value as a nearest valid SDF value + min_pts_sdf = self.feat_geo[min_dist_ind, min_dist_pts_ind] + # approximate SDF value with the same sign distance + L2 distance + approx_sdf = min_pts_sdf + min_dist * torch.sign(min_pts_sdf) + weighted_feat[bs_mask, 0:1] = approx_sdf[:, None] + return weighted_feat + + def prim_weight(self, x): + # x - [bs, 3] + weights = F.relu(1 - torch.norm((x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...], p = float('inf'), dim=-1)) + # weight - [bs, N] + normalized_weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6) + return normalized_weights + + def sdf2alpha(self, sdf): + return torch.exp(-(sdf / self.sdf2alpha_var) ** 2) + + @property + def pos(self): + return self.srt_param[:, 1:4] + + @property + def scale(self): + return self.srt_param[:, 0:1] + + @property + def feat(self): + return self.feat_param + + @property + def feat_geo(self): + return self.feat_param[:, self.geo_start_index:self.geo_end_index] + + @property + def feat_tex(self): + return self.feat_param[:, self.tex_start_index:self.tex_end_index] + + @property + def feat_mat(self): + return self.feat_param[:, self.mat_start_index:self.mat_end_index] \ No newline at end of file diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4af036508b8d86aae5a6b75eac51dc775c85831b --- /dev/null +++ b/models/utils.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +import numpy as np +import math +from itertools import repeat +import collections.abc +from .attention import MemEffAttention + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if 
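# Minimal sketch of prim_weight(): each primitive contributes a tent-shaped weight
# relu(1 - ||(x - pos) / scale||_inf) that is then normalised across primitives. Toy values only.
import torch
import torch.nn.functional as F

x = torch.tensor([[0.10, 0.00, 0.00]])                        # one query point, [bs, 3]
pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]])        # two primitive centres
scale = torch.tensor([[0.2], [0.2]])                          # per-primitive scales
w = F.relu(1 - torch.norm((x[:, None, :] - pos[None, ...]) / scale[None, ...],
                          p=float('inf'), dim=-1))
# w == tensor([[0.5, 0.0]]): the point lies inside the first primitive only
w_norm = w / (w.sum(dim=-1, keepdim=True) + 1e-6)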
isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x \ No newline at end of file diff --git a/models/vae3d_dib.py b/models/vae3d_dib.py new file mode 100644 index 0000000000000000000000000000000000000000..4da3cc60e1aba978b63f2092b0b1c280e913e797 --- /dev/null +++ b/models/vae3d_dib.py @@ -0,0 +1,454 @@ +import numpy as np +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from utils.typing import * +from .attention import MemEffAttention + +class VolumeAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + groups: 
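# Sketch of the sinusoidal timestep embedding from models/utils.py; assumes the module imports
# cleanly (it pulls in MemEffAttention and therefore xformers).
import torch
from models.utils import TimestepEmbedder

freq = TimestepEmbedder.timestep_embedding(torch.tensor([0, 250, 999]), dim=256)
print(freq.shape)                              # torch.Size([3, 256]); cosines in the first half, sines in the second
embedder = TimestepEmbedder(hidden_size=1152)
t_emb = embedder(torch.tensor([0, 250, 999]))  # [3, 1152] after the two-layer MLP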
int = 32, + eps: float = 1e-5, + residual: bool = True, + skip_scale: float = 1, + ): + super().__init__() + + self.residual = residual + self.skip_scale = skip_scale + + self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) + self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) + + def forward(self, x): + # x: [B, C, H, W, D] + B, C, H, W, D = x.shape + + res = x + x = self.norm(x) + + x = x.permute(0, 2, 3, 4, 1).reshape(B, -1, C) + x = self.attn(x) + x = x.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3).reshape(B, C, H, W, D) + + if self.residual: + x = (x + res) * self.skip_scale + + return x + +class DiagonalGaussianDistribution: + def __init__(self, parameters, deterministic=False): + # parameters: [B, 2C, ...] + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) + + def sample(self): + sample = torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype) + x = self.mean + self.std * sample + return x + + def kl(self, other=None, dims=[1, 2, 3, 4]): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=dims, + ) + + def nll(self, sample, dims=[1, 2, 3, 4]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resample: Literal['default', 'up', 'down'] = 'default', + groups: int = 32, + eps: float = 1e-5, + skip_scale: float = 1, # multiplied to output + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.skip_scale = skip_scale + + self.norm1 = nn.GroupNorm(num_groups=min(groups, in_channels), num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.norm2 = nn.GroupNorm(num_groups=min(groups, out_channels), num_channels=out_channels, eps=eps, affine=True) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.act = F.silu + + self.resample = None + if resample == 'up': + self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + elif resample == 'down': + self.resample = nn.AvgPool3d(kernel_size=2, stride=2) + + self.shortcut = nn.Identity() + if self.in_channels != self.out_channels: + self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True) + + + def forward(self, x): + res = x + + x = self.norm1(x) + x = self.act(x) + + if self.resample: + res = self.resample(res) + x = self.resample(x) + + x = self.conv1(x) + x = self.norm2(x) + x = self.act(x) + x = self.conv2(x) + + x = (x + self.shortcut(res)) * self.skip_scale + + return x + +class DownBlock(nn.Module): + def __init__( + self, + 
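# Minimal sketch of DiagonalGaussianDistribution: the encoder output is split into mean and
# log-variance along dim=1; assumes models.vae3d_dib is importable.
import torch
from models.vae3d_dib import DiagonalGaussianDistribution

params = torch.randn(2, 2 * 16, 4, 4, 4)       # [B, 2 * latent_channels, H, W, D]
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                              # [2, 16, 4, 4, 4]: mean + std * eps
kl = dist.kl()                                 # [2]: mean KL to N(0, I) over dims [1, 2, 3, 4]
code = dist.mode()                             # deterministic latent = mean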
in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample: bool = True, + skip_scale: float = 1, + gradient_checkpointing: bool = False, + ): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + + nets = [] + for i in range(num_layers): + cin = in_channels if i == 0 else out_channels + nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale)) + self.nets = nn.ModuleList(nets) + + self.downsample = None + if downsample: + self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + + def forward(self, x): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + def _forward(self, x): + + for net in self.nets: + x = net(x) + + if self.downsample: + x = self.downsample(x) + + return x + + +class MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + attention: bool = True, + attention_heads: int = 8, + skip_scale: float = 1, + gradient_checkpointing: bool = False, + ): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + + nets = [] + attns = [] + # first layer + nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) + # more layers + for i in range(num_layers): + nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) + if attention: + attns.append(VolumeAttention(in_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + def forward(self, x): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + def _forward(self, x): + x = self.nets[0](x) + for attn, net in zip(self.attns, self.nets[1:]): + if attn: + x = attn(x) + x = net(x) + return x + + +class UpBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + upsample: bool = True, + skip_scale: float = 1, + gradient_checkpointing: bool = False, + ): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + + nets = [] + for i in range(num_layers): + cin = in_channels if i == 0 else out_channels + nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale)) + + self.nets = nn.ModuleList(nets) + + self.upsample = None + if upsample: + self.upsample = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) + + def forward(self, x): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + def _forward(self, x): + + for net in self.nets: + x = net(x) + + if self.upsample: + x = self.upsample(x) + + return x + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int = 1, + out_channels: int = 2 * 16, # double_z + down_channels: Tuple[int, ...] 
= (8, 16, 32, 64), + mid_attention: bool = True, + layers_per_block: int = 2, + skip_scale: float = np.sqrt(0.5), + gradient_checkpointing: bool = False, + ): + super().__init__() + + # input (first downsample) + self.conv_in = nn.Conv3d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1) + + # down + down_blocks = [] + cout = down_channels[0] + for i in range(len(down_channels)): + cin = cout + cout = down_channels[i] + + down_blocks.append(DownBlock( + cin, cout, + num_layers=layers_per_block, + downsample=(i != len(down_channels) - 1), # not final layer + skip_scale=skip_scale, + gradient_checkpointing=gradient_checkpointing, + )) + self.down_blocks = nn.ModuleList(down_blocks) + + # mid + self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale) + + # last + self.norm_out = nn.GroupNorm(num_channels=down_channels[-1], num_groups=32, eps=1e-5) + self.conv_out = nn.Conv3d(down_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # x: [B, Cin, H, W, D] + + # first + x = self.conv_in(x) + + # down + for block in self.down_blocks: + x = block(x) + + # mid + x = self.mid_block(x) + + # last + x = self.norm_out(x) + x = F.silu(x) + x = self.conv_out(x) + + return x + + +class Decoder(nn.Module): + def __init__( + self, + in_channels: int = 16, + out_channels: int = 1, + up_channels: Tuple[int, ...] = (64, 32, 16, 8), + mid_attention: bool = True, + layers_per_block: int = 2, + skip_scale: float = np.sqrt(0.5), + gradient_checkpointing: bool = False, + ): + super().__init__() + + # first + self.conv_in = nn.Conv3d(in_channels, up_channels[0], kernel_size=3, stride=1, padding=1) + + # mid + self.mid_block = MidBlock(up_channels[0], attention=mid_attention, skip_scale=skip_scale) + + # up + up_blocks = [] + cout = up_channels[0] + for i in range(len(up_channels)): + cin = cout + cout = up_channels[i] + + up_blocks.append(UpBlock( + cin, cout, + num_layers=layers_per_block, + upsample=(i != len(up_channels) - 1), # not final layer + skip_scale=skip_scale, + gradient_checkpointing=gradient_checkpointing, + )) + self.up_blocks = nn.ModuleList(up_blocks) + + # last (upsample) + self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=min(32, up_channels[-1]), eps=1e-5) + self.conv_out = nn.ConvTranspose3d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # x: [B, Cin, H, W, D] + + # first + x = self.conv_in(x) + + # mid + x = self.mid_block(x) + + # up + for block in self.up_blocks: + x = block(x) + + # last + x = self.norm_out(x) + x = F.silu(x) + x = self.conv_out(x) + + return x + + +class VAE(nn.Module): + def __init__( + self, + in_channels: int = 1, + latent_channels: int = 16, + out_channels: int = 1, + down_channels: Tuple[int, ...] = (16, 32, 64, 128, 256), + mid_attention: bool = True, + up_channels: Tuple[int, ...] 
= (256, 128, 64, 32, 16), + layers_per_block: int = 2, + skip_scale: float = np.sqrt(0.5), + gradient_checkpointing: bool = False, + ): + super().__init__() + + # encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=2 * latent_channels, # double_z + down_channels=down_channels, + mid_attention=mid_attention, + layers_per_block=layers_per_block, + skip_scale=skip_scale, + gradient_checkpointing=gradient_checkpointing, + ) + + # decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_channels=up_channels, + mid_attention=mid_attention, + layers_per_block=layers_per_block, + skip_scale=skip_scale, + gradient_checkpointing=gradient_checkpointing, + ) + + # quant + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) + + def encode(self, x): + x = self.encoder(x) + x = self.quant_conv(x) + posterior = DiagonalGaussianDistribution(x) + return posterior + + def decode(self, x): + x = self.post_quant_conv(x) + x = self.decoder(x) + return x + + def forward(self, x, sample=True): + # x: [B, Cin, H, W, D] + + p = self.encode(x) + + if sample: + z = p.sample() + else: + z = p.mode() + + x = self.decode(z) + + return x, p \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce3f057f1caebe6804ae2653b61e13f83cf46d28 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +torch==2.1.2 +einops +xformers +omegaconf +opencv-python +libigl +trimesh==4.2.0 +pygltflib +pymeshlab==0.2 +PyMCubes +xatlas +git+https://github.com/NVlabs/nvdiffrast/ +scikit-learn +open_clip_torch +triton==2.1.0 +rembg +gradio +tqdm +transformers==4.40.1 +diffusers==0.19.3 +ninja +imageio +imageio-ffmpeg \ No newline at end of file diff --git a/utils/mesh.py b/utils/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..35416c0fa24e77ee7902364d722d65b2eb6c045e --- /dev/null +++ b/utils/mesh.py @@ -0,0 +1,944 @@ +import os +import cv2 +import torch +import trimesh +import numpy as np + +from .typing import * +from .op import safe_normalize, dot + +class Mesh: + """ + A torch-native trimesh class, with support for ``ply/obj/glb`` formats. + + Note: + This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture). + """ + def __init__( + self, + v: Optional[Tensor] = None, + f: Optional[Tensor] = None, + vn: Optional[Tensor] = None, + fn: Optional[Tensor] = None, + vt: Optional[Tensor] = None, + ft: Optional[Tensor] = None, + vc: Optional[Tensor] = None, # vertex color + albedo: Optional[Tensor] = None, + metallicRoughness: Optional[Tensor] = None, + device: Optional[torch.device] = None, + ): + """Init a mesh directly using all attributes. + + Args: + v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None. + f (Optional[Tensor]): faces, int [M, 3]. Defaults to None. + vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None. + fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None. + vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None. + ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None. + vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None. + albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None. 
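# Hedged round-trip sketch for the 3D VAE; assumes models.vae3d_dib imports cleanly
# (MidBlock attention relies on MemEffAttention / xformers) and a spatial size divisible by 2**4.
import torch
from models.vae3d_dib import VAE

vae = VAE(in_channels=1, latent_channels=16, out_channels=1)
x = torch.randn(1, 1, 32, 32, 32)
recon, posterior = vae(x, sample=True)
assert recon.shape == x.shape
kl = posterior.kl()                            # [1]: regulariser against a standard normal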
+ metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None. + device (Optional[torch.device]): torch device. Defaults to None. + """ + self.device = device + self.v = v + self.vn = vn + self.vt = vt + self.f = f + self.fn = fn + self.ft = ft + # will first see if there is vertex color to use + self.vc = vc + # only support a single albedo image + self.albedo = albedo + # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1] + # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html + self.metallicRoughness = metallicRoughness + + self.ori_center = 0 + self.ori_scale = 1 + + @classmethod + def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs): + """load mesh from path. + + Args: + path (str): path to mesh file, supports ply, obj, glb. + clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False. + resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True. + renormal (bool, optional): re-calc the vertex normals. Defaults to True. + retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False. + bound (float, optional): bound to resize. Defaults to 0.9. + front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'. + device (torch.device, optional): torch device. Defaults to None. + + Note: + a ``device`` keyword argument can be provided to specify the torch device. + If it's not provided, we will try to use ``'cuda'`` as the device if it's available. + + Returns: + Mesh: the loaded Mesh object. 
+ """ + # obj supports face uv + if path.endswith(".obj"): + mesh = cls.load_obj(path, **kwargs) + # trimesh only supports vertex uv, but can load more formats + else: + mesh = cls.load_trimesh(path, **kwargs) + + # clean + if clean: + from .meshutils import clean_mesh + vertices = mesh.v.detach().cpu().numpy() + triangles = mesh.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=False) + mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device) + mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device) + + print(f"[INFO] load mesh, v: {mesh.v.shape}, f: {mesh.f.shape}") + # auto-normalize + if resize: + mesh.auto_size(bound=bound) + # auto-fix normal + if renormal or mesh.vn is None: + mesh.auto_normal() + print(f"[INFO] load mesh, vn: {mesh.vn.shape}, fn: {mesh.fn.shape}") + # auto-fix texcoords + if retex or (mesh.albedo is not None and mesh.vt is None): + mesh.auto_uv(cache_path=path) + print(f"[INFO] load mesh, vt: {mesh.vt.shape}, ft: {mesh.ft.shape}") + + # rotate front dir to +z + if front_dir != "+z": + # axis switch + if "-z" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32) + elif "+x" in front_dir: + T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "-x" in front_dir: + T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "+y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + elif "-y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + else: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + # rotation (how many 90 degrees) + if '1' in front_dir: + T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + elif '2' in front_dir: + T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + elif '3' in front_dir: + T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + mesh.v @= T + mesh.vn @= T + + return mesh + + # load from obj file + @classmethod + def load_obj(cls, path, albedo_path=None, device=None): + """load an ``obj`` mesh. + + Args: + path (str): path to mesh. + albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None. + device (torch.device, optional): torch device. Defaults to None. + + Note: + We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension. + The `usemtl` statement is ignored, and we only use the last material path in `mtl` file. + + Returns: + Mesh: the loaded Mesh object. 
+ """ + assert os.path.splitext(path)[-1] == ".obj" + + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # load obj + with open(path, "r") as f: + lines = f.readlines() + + def parse_f_v(fv): + # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided) + # supported forms: + # f v1 v2 v3 + # f v1/vt1 v2/vt2 v3/vt3 + # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3 + # f v1//vn1 v2//vn2 v3//vn3 + xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")] + xs.extend([-1] * (3 - len(xs))) + return xs[0], xs[1], xs[2] + + vertices, texcoords, normals = [], [], [] + faces, tfaces, nfaces = [], [], [] + mtl_path = None + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0].lower() + # mtllib + if prefix == "mtllib": + mtl_path = split_line[1] + # usemtl + elif prefix == "usemtl": + pass # ignored + # v/vn/vt + elif prefix == "v": + vertices.append([float(v) for v in split_line[1:]]) + elif prefix == "vn": + normals.append([float(v) for v in split_line[1:]]) + elif prefix == "vt": + val = [float(v) for v in split_line[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == "f": + vs = split_line[1:] + nv = len(vs) + v0, t0, n0 = parse_f_v(vs[0]) + for i in range(nv - 2): # triangulate (assume vertices are ordered) + v1, t1, n1 = parse_f_v(vs[i + 1]) + v2, t2, n2 = parse_f_v(vs[i + 2]) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if len(normals) > 0 + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if len(normals) > 0 + else None + ) + + # see if there is vertex color + use_vertex_color = False + if mesh.v.shape[1] == 6: + use_vertex_color = True + mesh.vc = mesh.v[:, 3:] + mesh.v = mesh.v[:, :3] + print(f"[INFO] load obj mesh: use vertex color: {mesh.vc.shape}") + + # try to load texture image + if not use_vertex_color: + # try to retrieve mtl file + mtl_path_candidates = [] + if mtl_path is not None: + mtl_path_candidates.append(mtl_path) + mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path)) + mtl_path_candidates.append(path.replace(".obj", ".mtl")) + + mtl_path = None + for candidate in mtl_path_candidates: + if os.path.exists(candidate): + mtl_path = candidate + break + + # if albedo_path is not provided, try retrieve it from mtl + metallic_path = None + roughness_path = None + if mtl_path is not None and albedo_path is None: + with open(mtl_path, "r") as f: + lines = f.readlines() + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0] + + if "map_Kd" in prefix: + # assume relative path! 
+ albedo_path = os.path.join(os.path.dirname(path), split_line[1]) + print(f"[INFO] load obj mesh: use texture from: {albedo_path}") + elif "map_Pm" in prefix: + metallic_path = os.path.join(os.path.dirname(path), split_line[1]) + elif "map_Pr" in prefix: + roughness_path = os.path.join(os.path.dirname(path), split_line[1]) + + # still not found albedo_path, or the path doesn't exist + if albedo_path is None or not os.path.exists(albedo_path): + # init an empty texture + print(f"[INFO] load obj mesh: init empty albedo!") + # albedo = np.random.rand(1024, 1024, 3).astype(np.float32) + albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color + else: + albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) + albedo = albedo.astype(np.float32) / 255 + print(f"[INFO] load obj mesh: load texture: {albedo.shape}") + + mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) + + # try to load metallic and roughness + if metallic_path is not None and roughness_path is not None: + print(f"[INFO] load obj mesh: load metallicRoughness from: {metallic_path}, {roughness_path}") + metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED) + metallic = metallic.astype(np.float32) / 255 + roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED) + roughness = roughness.astype(np.float32) / 255 + metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1) + + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + + return mesh + + @classmethod + def load_trimesh(cls, path, device=None): + """load a mesh using ``trimesh.load()``. + + Can load various formats like ``glb`` and serves as a fallback. + + Note: + We will try to merge all meshes if the glb contains more than one, + but **this may cause the texture to lose**, since we only support one texture image! + + Args: + path (str): path to the mesh file. + device (torch.device, optional): torch device. Defaults to None. + + Returns: + Mesh: the loaded Mesh object. + """ + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # use trimesh to load ply/glb + _data = trimesh.load(path) + # always convert scene to mesh, and apply all transforms... 
+ if isinstance(_data, trimesh.Scene): + print(f"[INFO] load trimesh: concatenating {len(_data.geometry)} meshes.") + _concat = [] + # loop the scene graph and apply transform to each mesh + scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}} + for k, v in scene_graph.items(): + name = v['geometry'] + if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh): + transform = v['transform'] + _concat.append(_data.geometry[name].apply_transform(transform)) + _mesh = trimesh.util.concatenate(_concat) + else: + _mesh = _data + + if _mesh.visual.kind == 'vertex': + vertex_colors = _mesh.visual.vertex_colors + vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 + mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) + print(f"[INFO] load trimesh: use vertex color: {mesh.vc.shape}") + elif _mesh.visual.kind == 'texture': + _material = _mesh.visual.material + if isinstance(_material, trimesh.visual.material.PBRMaterial): + texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 + # load metallicRoughness if present + if _material.metallicRoughnessTexture is not None: + metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255 + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + elif isinstance(_material, trimesh.visual.material.SimpleMaterial): + texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 + else: + raise NotImplementedError(f"material type {type(_material)} not supported!") + mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous() + print(f"[INFO] load trimesh: load texture: {texture.shape}") + else: + texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) + mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) + print(f"[INFO] load trimesh: failed to load texture.") + + vertices = _mesh.vertices + + try: + texcoords = _mesh.visual.uv + texcoords[:, 1] = 1 - texcoords[:, 1] + except Exception as e: + texcoords = None + + try: + normals = _mesh.vertex_normals + except Exception as e: + normals = None + + # trimesh only support vertex uv... 
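+ # trimesh stores uv and normals per vertex rather than per face corner, so the same
+ # face index buffer is reused for positions, texcoords and normals below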
+ faces = tfaces = nfaces = _mesh.faces + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if texcoords is not None + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if normals is not None + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if texcoords is not None + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if normals is not None + else None + ) + + return mesh + + @classmethod + def parse_trimesh_data(cls, raw_data: trimesh.Trimesh, device=None): + mesh = cls() + mesh.device = device + _mesh = raw_data + + if _mesh.visual.kind == 'vertex': + vertex_colors = _mesh.visual.vertex_colors + vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 + mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) + print(f"[INFO] load trimesh: use vertex color: {mesh.vc.shape}") + elif _mesh.visual.kind == 'texture': + _material = _mesh.visual.material + if isinstance(_material, trimesh.visual.material.PBRMaterial): + if _material.baseColorTexture is not None: + texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 + else: + # if there is no texture, init a uniform white texture by default + # TODO: support alpha blending mode when texture is RGBA + texture = np.ones((64, 64, 3), dtype=np.float32) + if _material.baseColorFactor is not None: + if isinstance(_material.baseColorFactor, np.ndarray): + texture = texture * (_material.baseColorFactor[:3] / 255) + # load metallicRoughness if present + if _material.metallicRoughnessTexture is not None: + metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255 + # there could be metallicRoughness map with a single channel + if len(metallicRoughness.shape) == 2: + metallicRoughness = metallicRoughness[..., None].repeat(3, axis=-1) + else: + # init metallicRoughness if there is no predefined one + metallicRoughness = np.ones_like(texture, dtype=np.float32) + # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1] + # there could be metallicRoughness map with a single channel + if len(metallicRoughness.shape) == 2: + metallicRoughness = metallicRoughness[..., None].repeat(3, axis=-1) + # we only apply metallicFactor and roughnessFactor to the asset without pbr maps + if _material.metallicFactor is not None: + metallicRoughness[..., 2] *= _material.metallicFactor + if _material.roughnessFactor is not None: + metallicRoughness[..., 1] *= _material.roughnessFactor + mesh.metallicRoughness = torch.tensor(metallicRoughness[..., :3], dtype=torch.float32, device=device).contiguous() + elif isinstance(_material, trimesh.visual.material.SimpleMaterial): + texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 + else: + raise NotImplementedError(f"material type {type(_material)} not supported!") + # there could be texture map with a single channel + if len(texture.shape) == 2: + texture = texture[..., None].repeat(3, axis=-1) + mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous() + # print(f"[INFO] load trimesh: load texture: {texture.shape}") + else: + texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) + mesh.albedo = torch.tensor(texture, dtype=torch.float32, 
device=device) + print(f"[INFO] load trimesh: failed to load texture.") + + vertices = _mesh.vertices + + try: + texcoords = _mesh.visual.uv + # deal with repeated wrapping of texture which leads to uv coord larger than 1 + texcoords = texcoords - np.floor(texcoords) + texcoords[:, 1] = 1 - texcoords[:, 1] + except Exception as e: + # for textureless mesh, we map all texture coords to the first uv element + texcoords = torch.zeros(vertices.shape[0], 2).to(mesh.albedo) + + try: + normals = _mesh.vertex_normals + except Exception as e: + normals = None + + # trimesh only support vertex uv... + faces = tfaces = nfaces = _mesh.faces + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if texcoords is not None + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if normals is not None + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if texcoords is not None + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if normals is not None + else None + ) + + return mesh + + # sample surface (using trimesh) + def sample_surface(self, count: int): + """sample points on the surface of the mesh. + + Args: + count (int): number of points to sample. + + Returns: + torch.Tensor: the sampled points, float [count, 3]. + """ + _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy()) + points, face_idx = trimesh.sample.sample_surface(_mesh, count) + points = torch.from_numpy(points).float().to(self.device) + return points + + # aabb + def aabb(self): + """get the axis-aligned bounding box of the mesh. + + Returns: + Tuple[torch.Tensor]: the min xyz and max xyz of the mesh. + """ + return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values + + # unit size + @torch.no_grad() + def auto_size(self, bound=0.9): + """auto resize the mesh. + + Args: + bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9. + """ + vmin, vmax = self.aabb() + self.ori_center = (vmax + vmin) / 2 + self.ori_scale = 2 * bound / torch.max(vmax - vmin).item() + self.v = (self.v - self.ori_center) * self.ori_scale + + def auto_normal(self): + """auto calculate the vertex normals. + """ + i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long() + v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + vn = torch.zeros_like(self.v) + vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + vn = torch.where( + dot(vn, vn) > 1e-20, + vn, + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), + ) + vn = safe_normalize(vn) + + self.vn = vn + self.fn = self.f + + def auto_uv(self, cache_path=None, vmap=True): + """auto calculate the uv coordinates. + + Args: + cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None. + vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf). 
+ Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True. + """ + # try to load cache + if cache_path is not None: + cache_path = os.path.splitext(cache_path)[0] + "_uv.npz" + if cache_path is not None and os.path.exists(cache_path): + data = np.load(cache_path) + vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"] + else: + import xatlas + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().int().cpu().numpy() + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + # chart_options.max_iterations = 4 + atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # save to cache + if cache_path is not None: + np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping) + + vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) + ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) + self.vt = vt + self.ft = ft + + if vmap: + vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device) + self.align_v_to_vt(vmapping) + + def align_v_to_vt(self, vmapping=None): + """ remap v/f and vn/fn to vt/ft. + + Args: + vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None. + """ + if vmapping is None: + ft = self.ft.view(-1).long() + f = self.f.view(-1).long() + vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device) + vmapping[ft] = f # scatter, randomly choose one if index is not unique + + self.v = self.v[vmapping] + self.f = self.ft + + if self.vn is not None: + self.vn = self.vn[vmapping] + self.fn = self.ft + + def to(self, device): + """move all tensor attributes to device. + + Args: + device (torch.device): target device. + + Returns: + Mesh: self. + """ + self.device = device + for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]: + tensor = getattr(self, name) + if tensor is not None: + setattr(self, name, tensor.to(device)) + return self + + def write(self, path): + """write the mesh to a path. + + Args: + path (str): path to write, supports ply, obj and glb. + """ + if path.endswith(".ply"): + self.write_ply(path) + elif path.endswith(".obj"): + self.write_obj(path) + elif path.endswith(".glb") or path.endswith(".gltf"): + self.write_glb(path) + else: + raise NotImplementedError(f"format {path} not supported!") + + def write_ply(self, path): + """write the mesh in ply format. Only for geometry! + + Args: + path (str): path to write. + """ + + if self.albedo is not None: + print(f'[WARN] ply format does not support exporting texture, will ignore!') + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().cpu().numpy() + + _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np) + _mesh.export(path) + + + def write_glb(self, path): + """write the mesh in glb/gltf format. + This will create a scene with a single mesh. + + Args: + path (str): path to write. 
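+ 
+ Example:
+ An illustrative round-trip (the file names are placeholders)::
+ 
+ mesh = Mesh.load_trimesh("asset.ply")
+ mesh.auto_uv()
+ mesh.write_glb("asset.glb")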
+ """ + + # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0] + if self.vt is not None and self.v.shape[0] != self.vt.shape[0]: + self.align_v_to_vt() + + import pygltflib + + f_np = self.f.detach().cpu().numpy().astype(np.uint32) + f_np_blob = f_np.flatten().tobytes() + + v_np = self.v.detach().cpu().numpy().astype(np.float32) + v_np_blob = v_np.tobytes() + + blob = f_np_blob + v_np_blob + byteOffset = len(blob) + + # base mesh + gltf = pygltflib.GLTF2( + scene=0, + scenes=[pygltflib.Scene(nodes=[0])], + nodes=[pygltflib.Node(mesh=0)], + meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive( + # indices to accessors (0 is triangles) + attributes=pygltflib.Attributes( + POSITION=1, + ), + indices=0, + )])], + buffers=[ + pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob)) + ], + # buffer view (based on dtype) + bufferViews=[ + # triangles; as flatten (element) array + pygltflib.BufferView( + buffer=0, + byteLength=len(f_np_blob), + target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963) + ), + # positions; as vec3 array + pygltflib.BufferView( + buffer=0, + byteOffset=len(f_np_blob), + byteLength=len(v_np_blob), + byteStride=12, # vec3 + target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962) + ), + ], + accessors=[ + # 0 = triangles + pygltflib.Accessor( + bufferView=0, + componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125) + count=f_np.size, + type=pygltflib.SCALAR, + max=[int(f_np.max())], + min=[int(f_np.min())], + ), + # 1 = positions + pygltflib.Accessor( + bufferView=1, + componentType=pygltflib.FLOAT, # GL_FLOAT (5126) + count=len(v_np), + type=pygltflib.VEC3, + max=v_np.max(axis=0).tolist(), + min=v_np.min(axis=0).tolist(), + ), + ], + ) + + # append texture info + if self.vt is not None: + + vt_np = self.vt.detach().cpu().numpy().astype(np.float32) + vt_np_blob = vt_np.tobytes() + + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) + albedo_blob = cv2.imencode('.png', albedo)[1].tobytes() + + # update primitive + gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2 + gltf.meshes[0].primitives[0].material = 0 + + # update materials + gltf.materials.append(pygltflib.Material( + pbrMetallicRoughness=pygltflib.PbrMetallicRoughness( + baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0), + metallicFactor=0.0, + roughnessFactor=1.0, + ), + alphaMode=pygltflib.OPAQUE, + alphaCutoff=None, + doubleSided=True, + )) + + gltf.textures.append(pygltflib.Texture(sampler=0, source=0)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 2, texcoords; as vec2 array + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(vt_np_blob), + byteStride=8, # vec2 + target=pygltflib.ARRAY_BUFFER, + ) + ) + + gltf.accessors.append( + # 2 = texcoords + pygltflib.Accessor( + bufferView=2, + componentType=pygltflib.FLOAT, + count=len(vt_np), + type=pygltflib.VEC2, + max=vt_np.max(axis=0).tolist(), + min=vt_np.min(axis=0).tolist(), + ) + ) + + blob += vt_np_blob + byteOffset += len(vt_np_blob) + + gltf.bufferViews.append( + # index = 3, albedo texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(albedo_blob), + ) + ) + + blob += 
albedo_blob + byteOffset += len(albedo_blob) + + gltf.buffers[0].byteLength = byteOffset + + # append metllic roughness + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR) + metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes() + + # update texture definition + gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0) + + gltf.textures.append(pygltflib.Texture(sampler=1, source=1)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 4, metallicRoughness texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(metallicRoughness_blob), + ) + ) + + blob += metallicRoughness_blob + byteOffset += len(metallicRoughness_blob) + + gltf.buffers[0].byteLength = byteOffset + + + # set actual data + gltf.set_binary_blob(blob) + + # glb = b"".join(gltf.save_to_bytes()) + gltf.save(path) + + + def write_obj(self, path): + """write the mesh in obj format. Will also write the texture and mtl files. + + Args: + path (str): path to write. + """ + + mtl_path = path.replace(".obj", ".mtl") + albedo_path = path.replace(".obj", "_albedo.png") + metallic_path = path.replace(".obj", "_metallic.png") + roughness_path = path.replace(".obj", "_roughness.png") + + v_np = self.v.detach().cpu().numpy() + vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None + vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None + f_np = self.f.detach().cpu().numpy() + ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None + fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None + + with open(path, "w") as fp: + fp.write(f"mtllib {os.path.basename(mtl_path)} \n") + + for v in v_np: + fp.write(f"v {v[0]} {v[1]} {v[2]} \n") + + if vt_np is not None: + for v in vt_np: + fp.write(f"vt {v[0]} {1 - v[1]} \n") + + if vn_np is not None: + for v in vn_np: + fp.write(f"vn {v[0]} {v[1]} {v[2]} \n") + + fp.write(f"usemtl defaultMat \n") + for i in range(len(f_np)): + fp.write( + f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ + {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ + {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n' + ) + + with open(mtl_path, "w") as fp: + fp.write(f"newmtl defaultMat \n") + fp.write(f"Ka 1 1 1 \n") + fp.write(f"Kd 1 1 1 \n") + fp.write(f"Ks 0 0 0 \n") + fp.write(f"Tr 1 \n") + fp.write(f"illum 1 \n") + fp.write(f"Ns 0 \n") + if self.albedo is not None: + fp.write(f"map_Kd {os.path.basename(albedo_path)} \n") + if self.metallicRoughness is not None: + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + fp.write(f"map_Pm {os.path.basename(metallic_path)} \n") + fp.write(f"map_Pr {os.path.basename(roughness_path)} \n") + + if self.albedo is not None: + 
albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) + + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + cv2.imwrite(metallic_path, metallicRoughness[..., 2]) + cv2.imwrite(roughness_path, metallicRoughness[..., 1]) + diff --git a/utils/meshutils.py b/utils/meshutils.py new file mode 100644 index 0000000000000000000000000000000000000000..f95da313c1e4a0f54ef1a82933bca2508bcbf7c9 --- /dev/null +++ b/utils/meshutils.py @@ -0,0 +1,233 @@ +import torch +import numpy as np +import pymeshlab as pml +from importlib.metadata import version + +PML_VER = version('pymeshlab') + +# the code assumes the latest 2023.12 version, but we can patch older versions +if PML_VER.startswith('0.2'): + # monkey patch for 0.2 (only the used functions in this file!) + pml.MeshSet.meshing_decimation_quadric_edge_collapse = pml.MeshSet.simplification_quadric_edge_collapse_decimation + pml.MeshSet.meshing_isotropic_explicit_remeshing = pml.MeshSet.remeshing_isotropic_explicit_remeshing + pml.MeshSet.meshing_remove_unreferenced_vertices = pml.MeshSet.remove_unreferenced_vertices + pml.MeshSet.meshing_merge_close_vertices = pml.MeshSet.merge_close_vertices + pml.MeshSet.meshing_remove_duplicate_faces = pml.MeshSet.remove_duplicate_faces + pml.MeshSet.meshing_remove_null_faces = pml.MeshSet.remove_zero_area_faces + pml.MeshSet.meshing_remove_connected_component_by_diameter = pml.MeshSet.remove_isolated_pieces_wrt_diameter + pml.MeshSet.meshing_remove_connected_component_by_face_number = pml.MeshSet.remove_isolated_pieces_wrt_face_num + pml.MeshSet.meshing_repair_non_manifold_edges = pml.MeshSet.repair_non_manifold_edges_by_removing_faces + pml.MeshSet.meshing_repair_non_manifold_vertices = pml.MeshSet.repair_non_manifold_vertices_by_splitting + pml.PercentageValue = pml.Percentage + pml.PureValue = float +elif PML_VER.startswith('2022.2'): + # monkey patch for 2022.2 + pml.PercentageValue = pml.Percentage + pml.PureValue = pml.AbsoluteValue + +def rotation_matrix(axis, angle_deg): + angle_rad = np.radians(angle_deg) + if axis == 'x': + return np.array([[1, 0, 0], + [0, np.cos(angle_rad), -np.sin(angle_rad)], + [0, np.sin(angle_rad), np.cos(angle_rad)]]).astype(np.float32) + elif axis == 'y': + return np.array([[np.cos(angle_rad), 0, np.sin(angle_rad)], + [0, 1, 0], + [-np.sin(angle_rad), 0, np.cos(angle_rad)]]).astype(np.float32) + elif axis == 'z': + return np.array([[np.cos(angle_rad), -np.sin(angle_rad), 0], + [np.sin(angle_rad), np.cos(angle_rad), 0], + [0, 0, 1]]).astype(np.float32) + else: + raise ValueError("Axis must be 'x', 'y', or 'z'") + +def scale_to_unit_sphere(points): + max_xyz, _ = points.max(0) + min_xyz, _ = points.min(0) + bb_centroid = (max_xyz + min_xyz) / 2. + zero_mean_points = points - bb_centroid + dist = np.linalg.norm(points, axis=1) + normalized_points = zero_mean_points / np.max(dist) + return normalized_points + +def scale_to_unit_cube(points): + max_xyz, _ = points.max(0) + min_xyz, _ = points.min(0) + bb_centroid = (max_xyz + min_xyz) / 2. + global_scale_max = (max_xyz - min_xyz).max() + zero_mean_points = points - bb_centroid + normalized_points = zero_mean_points * (1.8 / global_scale_max) + return normalized_points + +def decimate_mesh( + verts, faces, target=5e4, backend="pymeshlab", remesh=False, optimalplacement=True +): + """ perform mesh decimation. 
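+ The default "pymeshlab" backend decimates with quadric edge collapse; the alternative
+ "pyfqmr" backend is imported lazily inside this function and must be installed separately.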
+ + Args:pml + verts (np.ndarray): mesh vertices, float [N, 3] + faces (np.ndarray): mesh faces, int [M, 3] + target (int): targeted number of faces + backend (str, optional): algorithm backend, can be "pymeshlab" or "pyfqmr". Defaults to "pymeshlab". + remesh (bool, optional): whether to remesh after decimation. Defaults to False. + optimalplacement (bool, optional): For flat mesh, use False to prevent spikes. Defaults to True. + + Returns: + Tuple[np.ndarray]: vertices and faces after decimation. + """ + + _ori_vert_shape = verts.shape + _ori_face_shape = faces.shape + + if backend == "pyfqmr": + import pyfqmr + + solver = pyfqmr.Simplify() + solver.setMesh(verts, faces) + solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False) + verts, faces, normals = solver.getMesh() + else: + m = pml.Mesh(verts, faces) + ms = pml.MeshSet() + ms.add_mesh(m, "mesh") # will copy! + + # filters + # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1)) + ms.meshing_decimation_quadric_edge_collapse( + targetfacenum=int(target), optimalplacement=optimalplacement + ) + + if remesh: + # ms.apply_coord_taubin_smoothing() + ms.meshing_isotropic_explicit_remeshing( + iterations=3, targetlen=pml.PercentageValue(1) + ) + + # extract mesh + m = ms.current_mesh() + m.compact() + verts = m.vertex_matrix() + faces = m.face_matrix() + + print(f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}") + + return verts, faces + + +def clean_mesh( + verts, + faces, + v_pct=1, + min_f=64, + min_d=20, + repair=True, + remesh=True, + remesh_size=0.01, + remesh_iters=3, +): + """ perform mesh cleaning, including floater removal, non manifold repair, and remeshing. + + Args: + verts (np.ndarray): mesh vertices, float [N, 3] + faces (np.ndarray): mesh faces, int [M, 3] + v_pct (int, optional): percentage threshold to merge close vertices. Defaults to 1. + min_f (int, optional): maximal number of faces for isolated component to remove. Defaults to 64. + min_d (int, optional): maximal diameter percentage of isolated component to remove. Defaults to 20. + repair (bool, optional): whether to repair non-manifold faces (cannot gurantee). Defaults to True. + remesh (bool, optional): whether to perform a remeshing after all cleaning. Defaults to True. + remesh_size (float, optional): the targeted edge length for remeshing. Defaults to 0.01. + remesh_iters (int, optional): the iterations of remeshing. Defaults to 3. + + Returns: + Tuple[np.ndarray]: vertices and faces after decimation. + """ + # verts: [N, 3] + # faces: [N, 3] + + _ori_vert_shape = verts.shape + _ori_face_shape = faces.shape + + m = pml.Mesh(verts, faces) + ms = pml.MeshSet() + ms.add_mesh(m, "mesh") # will copy! 
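+ # overview of the filter chain applied below:
+ # 1) drop unreferenced vertices and merge near-duplicate vertices (v_pct)
+ # 2) remove duplicate and zero-area faces
+ # 3) delete small floating components (by diameter min_d and by face count min_f)
+ # 4) optionally repair non-manifold edges and vertices
+ # 5) optionally run isotropic explicit remeshing at the requested edge length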
+ + # filters + ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces + + if v_pct > 0: + ms.meshing_merge_close_vertices( + threshold=pml.PercentageValue(v_pct) + ) # 1/10000 of bounding box diagonal + + ms.meshing_remove_duplicate_faces() # faces defined by the same verts + ms.meshing_remove_null_faces() # faces with area == 0 + + if min_d > 0: + ms.meshing_remove_connected_component_by_diameter( + mincomponentdiag=pml.PercentageValue(min_d) + ) + + if min_f > 0: + ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) + + if repair: + # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) + ms.meshing_repair_non_manifold_edges(method=0) + ms.meshing_repair_non_manifold_vertices(vertdispratio=0) + + if remesh: + # ms.apply_coord_taubin_smoothing() + ms.meshing_isotropic_explicit_remeshing( + iterations=remesh_iters, targetlen=pml.PureValue(remesh_size) + ) + + # extract mesh + m = ms.current_mesh() + m.compact() + verts = m.vertex_matrix() + faces = m.face_matrix() + + print(f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}") + + return verts, faces + +@torch.no_grad() +def compute_edge_to_face_mapping(faces): + """ compute edge to face mapping. + + Args: + faces (torch.Tensor): mesh faces, int [M, 3] + + Returns: + torch.Tensor: indices to faces for each edge, long, [N, 2] + """ + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((faces[:, 0], faces[:, 1]), dim=-1), + torch.stack((faces[:, 1], faces[:, 2]), dim=-1), + torch.stack((faces[:, 2], faces[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(faces.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge \ No newline at end of file diff --git a/utils/op.py b/utils/op.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a33fc498b1c99e445fc886ca70389d0a5c1858 --- /dev/null +++ b/utils/op.py @@ -0,0 +1,47 @@ +import torch +import numpy as np +from .typing import * + +# torch / numpy math utils +def dot(x: Union[Tensor, ndarray], y: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]: + """dot product (along the last dim). + + Args: + x (Union[Tensor, ndarray]): x, [..., C] + y (Union[Tensor, ndarray]): y, [..., C] + + Returns: + Union[Tensor, ndarray]: x dot y, [..., 1] + """ + if isinstance(x, np.ndarray): + return np.sum(x * y, -1, keepdims=True) + else: + return torch.sum(x * y, -1, keepdim=True) + +def length(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]: + """length of an array (along the last dim). + + Args: + x (Union[Tensor, ndarray]): x, [..., C] + eps (float, optional): eps. Defaults to 1e-20. 
+ + Returns: + Union[Tensor, ndarray]: length, [..., 1] + """ + if isinstance(x, np.ndarray): + return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) + else: + return torch.sqrt(torch.clamp(dot(x, x), min=eps)) + +def safe_normalize(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]: + """normalize an array (along the last dim). + + Args: + x (Union[Tensor, ndarray]): x, [..., C] + eps (float, optional): eps. Defaults to 1e-20. + + Returns: + Union[Tensor, ndarray]: normalized x, [..., C] + """ + + return x / length(x, eps) \ No newline at end of file diff --git a/utils/typing.py b/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac09f4ef0ac99d85353d4945eda5ed74db1c631 --- /dev/null +++ b/utils/typing.py @@ -0,0 +1,4 @@ +# ref: https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html +from typing import Sequence, List, Tuple, Dict, Any, Optional, Union, Literal, Callable +from torch import Tensor +from numpy import ndarray \ No newline at end of file