diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..82a7236bf03797a01d3ad6e6a60cdd8cdf6d0c1a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__
+build
+*.so
+runs
\ No newline at end of file
diff --git a/README.md b/README.md
index 587f35499d53dab50cbb593a5974f4643e9d558d..14234d8508d048efcfff1a305dae7ce7749aced2 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,11 @@
---
-title: 3DTopia XL
+title: 3DTopia-XL
emoji: 🌖
colorFrom: green
colorTo: pink
sdk: gradio
sdk_version: 4.41.0
+python_version: 3.9
app_file: app.py
pinned: false
---
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d509be562a20f610f4250deab89b5f53c350b761
--- /dev/null
+++ b/app.py
@@ -0,0 +1,209 @@
+import os
+import imageio
+import numpy as np
+
+os.system("bash install.sh")
+
+from omegaconf import OmegaConf
+import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+import rembg
+import gradio as gr
+from dva.io import load_from_config
+from dva.ray_marcher import RayMarcher
+from dva.visualize import visualize_primvolume, visualize_video_primvolume
+from inference import remove_background, resize_foreground, extract_texmesh
+from models.diffusion import create_diffusion
+from huggingface_hub import hf_hub_download
+ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
+vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")
+
+GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
+GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
+GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
+GRADIO_GLB_PATH = 'pbr_mesh.glb'
+CONFIG_PATH = "./configs/inference_dit.yml"
+
+config = OmegaConf.load(CONFIG_PATH)
+config.checkpoint_path = ckpt_path
+config.model.vae_checkpoint_path = vae_ckpt_path
+# model
+model = load_from_config(config.model.generator)
+state_dict = torch.load(config.checkpoint_path, map_location='cpu')
+model.load_state_dict(state_dict['ema'])
+vae = load_from_config(config.model.vae)
+vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
+vae.load_state_dict(vae_state_dict['model_state_dict'])
+conditioner = load_from_config(config.model.conditioner)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+vae = vae.to(device)
+conditioner = conditioner.to(device)
+model = model.to(device)
+model.eval()
+
+amp = True
+precision_dtype = torch.float16
+
+rm = RayMarcher(
+ config.image_height,
+ config.image_width,
+ **config.rm,
+).to(device)
+
+perchannel_norm = False
+if "latent_mean" in config.model:
+ latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
+ latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
+ assert latent_mean.shape[-1] == config.model.generator.in_channels
+ perchannel_norm = True
+
+config.diffusion.pop("timestep_respacing")
+config.model.pop("vae")
+config.model.pop("vae_checkpoint_path")
+config.model.pop("conditioner")
+config.model.pop("generator")
+config.model.pop("latent_nf")
+config.model.pop("latent_mean")
+config.model.pop("latent_std")
+model_primx = load_from_config(config.model)
+# load rembg
+rembg_session = rembg.new_session()
+
+# process function
+def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
+ # seed
+ torch.manual_seed(input_seed)
+
+ os.makedirs(config.output_dir, exist_ok=True)
+ output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
+ output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
+ output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
+ output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
+
+ # derive the DDIM respacing spec from the requested step count
+ # (assumed format "ddimN", matching diffusion.ddim_sample_loop_progressive below)
+ respacing = "ddim{}".format(input_num_steps)
+ diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
+ sample_fn = diffusion.ddim_sample_loop_progressive
+ fwd_fn = model.forward_with_cfg
+
+ # text-conditioned
+ if input_image is None:
+ raise NotImplementedError
+ # image-conditioned (a text prompt may also be given, but image-only usually works)
+ else:
+ input_image = remove_background(input_image, rembg_session)
+ input_image = resize_foreground(input_image, 0.85)
+ raw_image = np.array(input_image)
+ mask = (raw_image[..., -1][..., None] > 0) * 1
+ raw_image = raw_image[..., :3] * mask
+ input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
+
+ with torch.no_grad():
+ latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
+ batch = {}
+ inf_bs = 1
+ inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
+ y = conditioner.encoder(input_cond)
+ model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
+ if input_cfg >= 0:
+ model_kwargs['cfg_scale'] = input_cfg
+ for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
+ final_samples = samples
+ recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
+ if perchannel_norm:
+ recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
+ recon_srt_param = recon_param[:, :, 0:4]
+ recon_feat_param = recon_param[:, :, 4:] # [inf_bs, num_prims, 64]
+ recon_feat_param_list = []
+ # one-by-one to avoid oom
+ for inf_bidx in range(inf_bs):
+ if not perchannel_norm:
+ decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
+ else:
+ decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
+ recon_feat_param_list.append(decoded.detach())
+ recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
+ # invert normalization
+ if not perchannel_norm:
+ recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
+ recon_feat_param[:, 0:1, ...] /= 5.
+ recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
+ recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
+ recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
+ visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device)
+ prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
+ torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir))
+
+ # exporting GLB mesh
+ denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')
+ primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
+ model_primx.load_state_dict(primx_ckpt_weight)
+ model_primx.to(device)
+ model_primx.eval()
+ with torch.no_grad():
+ model_primx.srt_param[:, 1:4] *= 0.85
+ extract_texmesh(config.inference, model_primx, output_glb_path, device)
+
+ return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path
+
+# gradio UI
+_TITLE = '''3DTopia-XL'''
+
+_DESCRIPTION = '''
+* We currently offer 1) the single-image-conditioned model; 2) the multi-view-image-conditioned model and 3) the pure-text-conditioned model will be released in the future!
+* If you find the output unsatisfying, try using different seeds!
+'''
+
+block = gr.Blocks(title=_TITLE).queue()
+with block:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown('# ' + _TITLE)
+ gr.Markdown(_DESCRIPTION)
+
+ with gr.Row(variant='panel'):
+ with gr.Column(scale=1):
+ # input image
+ input_image = gr.Image(label="image", type='pil')
+ # inference steps
+ input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25)
+ # CFG scale
+ input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6)
+ # random seed
+ input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42)
+ # gen button
+ button_gen = gr.Button("Generate")
+
+ with gr.Column(scale=1):
+ with gr.Tab("Video"):
+ # final video results
+ output_rgb_video = gr.Video(label="RGB video")
+ output_prim_video = gr.Video(label="primitive video")
+ output_mat_video = gr.Video(label="material video")
+ with gr.Tab("GLB"):
+ # glb file
+ output_glb = gr.File(label="glb")
+
+ button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb])
+
+ gr.Examples(
+ examples=[
+ "assets/examples/fruit_elephant.jpg",
+ "assets/examples/mei_ling_panda.png",
+ "assets/examples/shuai_panda_notail.png",
+ ],
+ inputs=[input_image],
+ outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
+ fn=lambda x: process(input_image=x),
+ cache_examples=False,
+ label='Single Image to 3D PBR Asset'
+ )
+
+block.launch(server_name="0.0.0.0", share=True)
\ No newline at end of file
diff --git a/assets/examples/blue_cat.png b/assets/examples/blue_cat.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa933e07cc2d6b058900730649637d75be19ef5f
Binary files /dev/null and b/assets/examples/blue_cat.png differ
diff --git a/assets/examples/bubble_mart_blue.png b/assets/examples/bubble_mart_blue.png
new file mode 100644
index 0000000000000000000000000000000000000000..af870322d4a8a2f237546fbea9560bb8e5f50364
Binary files /dev/null and b/assets/examples/bubble_mart_blue.png differ
diff --git a/assets/examples/bulldog.png b/assets/examples/bulldog.png
new file mode 100644
index 0000000000000000000000000000000000000000..16c598a8133643898408ea806b69d5b18c53be7d
Binary files /dev/null and b/assets/examples/bulldog.png differ
diff --git a/assets/examples/ceramic.png b/assets/examples/ceramic.png
new file mode 100644
index 0000000000000000000000000000000000000000..46a2a336ea869397376bab68cbc45c690cc6c617
Binary files /dev/null and b/assets/examples/ceramic.png differ
diff --git a/assets/examples/chair_watermelon.png b/assets/examples/chair_watermelon.png
new file mode 100644
index 0000000000000000000000000000000000000000..52b39917abcbd2f1eef9b7c8cf9aa602bddde1bf
Binary files /dev/null and b/assets/examples/chair_watermelon.png differ
diff --git a/assets/examples/cup_rgba.png b/assets/examples/cup_rgba.png
new file mode 100644
index 0000000000000000000000000000000000000000..e730cdd960b8e552dea17959bd54f52cb6f941ce
Binary files /dev/null and b/assets/examples/cup_rgba.png differ
diff --git a/assets/examples/cute_horse.jpg b/assets/examples/cute_horse.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ec8807d313b983e3cc34ee89bbf3f312d6ce66eb
Binary files /dev/null and b/assets/examples/cute_horse.jpg differ
diff --git a/assets/examples/earphone.jpg b/assets/examples/earphone.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..498e4196b0d68f8809d049e7178b80592a31a0a2
Binary files /dev/null and b/assets/examples/earphone.jpg differ
diff --git a/assets/examples/firedragon.png b/assets/examples/firedragon.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d6c54180f5a2d362c5eb9b54a4ef3a3222516eb
Binary files /dev/null and b/assets/examples/firedragon.png differ
diff --git a/assets/examples/fox.jpg b/assets/examples/fox.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f2efc1c3a9c4ad8f36ad93082c124c91a6e9ef7
Binary files /dev/null and b/assets/examples/fox.jpg differ
diff --git a/assets/examples/fruit_elephant.jpg b/assets/examples/fruit_elephant.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ef8eaf3b88ae0a38272b34802fe40032055afa58
Binary files /dev/null and b/assets/examples/fruit_elephant.jpg differ
diff --git a/assets/examples/hatsune_miku.png b/assets/examples/hatsune_miku.png
new file mode 100644
index 0000000000000000000000000000000000000000..2fecf005fdd56a396c4894256fbb98fcc1c4dd8f
Binary files /dev/null and b/assets/examples/hatsune_miku.png differ
diff --git a/assets/examples/ikun_rgba.png b/assets/examples/ikun_rgba.png
new file mode 100644
index 0000000000000000000000000000000000000000..985c37037d83a75954da4e104bbbec4f550130e2
Binary files /dev/null and b/assets/examples/ikun_rgba.png differ
diff --git a/assets/examples/mailbox.png b/assets/examples/mailbox.png
new file mode 100644
index 0000000000000000000000000000000000000000..cce19d56d5a68b7eeec6848dd97a1ca7be30b520
Binary files /dev/null and b/assets/examples/mailbox.png differ
diff --git a/assets/examples/mario.png b/assets/examples/mario.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9805fdcb31e2f7f036e830e03a045e169319ddf
Binary files /dev/null and b/assets/examples/mario.png differ
diff --git a/assets/examples/mei_ling_panda.png b/assets/examples/mei_ling_panda.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7d5392b77385ff9670579ef4a10aaa5211e8067
Binary files /dev/null and b/assets/examples/mei_ling_panda.png differ
diff --git a/assets/examples/mushroom_teapot.jpg b/assets/examples/mushroom_teapot.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a6c767354305f5467a4c0d5f199eee2a120f4501
Binary files /dev/null and b/assets/examples/mushroom_teapot.jpg differ
diff --git a/assets/examples/pikachu.png b/assets/examples/pikachu.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7579c16957a3e13b80d53cf0a41ddfdfd47b92d
Binary files /dev/null and b/assets/examples/pikachu.png differ
diff --git a/assets/examples/potplant_rgba.png b/assets/examples/potplant_rgba.png
new file mode 100644
index 0000000000000000000000000000000000000000..4aee4cdf0be69650d465faaa8585e01092e41ccb
Binary files /dev/null and b/assets/examples/potplant_rgba.png differ
diff --git a/assets/examples/seed_frog.png b/assets/examples/seed_frog.png
new file mode 100644
index 0000000000000000000000000000000000000000..35524f5a10fd205370ccd5881aa6deab5431e771
Binary files /dev/null and b/assets/examples/seed_frog.png differ
diff --git a/assets/examples/shuai_panda_notail.png b/assets/examples/shuai_panda_notail.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a4fb91fc6fc2c3432ab42a2b83c1cc80760a3b9
Binary files /dev/null and b/assets/examples/shuai_panda_notail.png differ
diff --git a/assets/examples/yellow_duck.png b/assets/examples/yellow_duck.png
new file mode 100644
index 0000000000000000000000000000000000000000..02246dd65c52e0b2cf5a6daba96deeab926928ff
Binary files /dev/null and b/assets/examples/yellow_duck.png differ
diff --git a/configs/inference_dit.yml b/configs/inference_dit.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0820949ffd5217cb3d43b9632fc6e4198b15d277
--- /dev/null
+++ b/configs/inference_dit.yml
@@ -0,0 +1,97 @@
+debug: False
+root_data_dir: ./runs
+checkpoint_path:
+global_seed: 42
+
+inference:
+ input_dir:
+ ddim: 25
+ cfg: 6
+ seed: ${global_seed}
+ precision: fp16
+ export_glb: True
+ decimate: 100000
+ mc_resolution: 256
+ batch_size: 4096
+ remesh: False
+
+image_height: 518
+image_width: 518
+
+model:
+ class_name: models.primsdf.PrimSDF
+ num_prims: 2048
+ dim_feat: 6
+ prim_shape: 8
+ init_scale: 0.05 # useless if auto_scale_init == True
+ sdf2alpha_var: 0.005
+ auto_scale_init: True
+ init_sampling: uniform
+ vae:
+ class_name: models.vae3d_dib.VAE
+ in_channels: ${model.dim_feat}
+ latent_channels: 1
+ out_channels: ${model.vae.in_channels}
+ down_channels: [32, 256]
+ mid_attention: True
+ up_channels: [256, 32]
+ layers_per_block: 2
+ gradient_checkpointing: False
+ vae_checkpoint_path:
+ conditioner:
+ class_name: models.conditioner.image.ImageConditioner
+ num_prims: ${model.num_prims}
+ dim_feat: ${model.dim_feat}
+ prim_shape: ${model.prim_shape}
+ sample_view: False
+ encoder_config:
+ class_name: models.conditioner.image_dinov2.Dinov2Wrapper
+ model_name: dinov2_vitb14_reg
+ freeze: True
+ generator:
+ class_name: models.dit_crossattn.DiT
+ seq_length: ${model.num_prims}
+ in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3
+ condition_channels: 768
+ hidden_size: 1152
+ depth: 28
+ num_heads: 16
+ attn_proj_bias: True
+ cond_drop_prob: 0.1
+ gradient_checkpointing: False
+ latent_nf: 1.0
+ latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166]
+ latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916]
+
+diffusion:
+ timestep_respacing:
+ noise_schedule: squaredcos_cap_v2
+ diffusion_steps: 1000
+ parameterization: v
+
+rm:
+ volradius: 10000.0
+ dt: 1.0
+
+optimizer:
+ class_name: torch.optim.AdamW
+ lr: 0.0001
+ weight_decay: 0
+
+scheduler:
+ class_name: dva.scheduler.CosineWarmupScheduler
+ warmup_iters: 3000
+ max_iters: 200000
+
+train:
+ batch_size: 8
+ n_workers: 4
+ n_epochs: 1000
+ log_every_n_steps: 50
+ summary_every_n_steps: 10000
+ ckpt_every_n_steps: 10000
+ amp: False
+ precision: tf32
+
+tag: 3dtopia-xl-sview
+output_dir: ${root_data_dir}/inference/${tag}
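+
+# Note: ${...} values above are OmegaConf interpolations resolved at load time.
+# timestep_respacing is left empty on purpose: app.py pops it from this config
+# and passes a per-request respacing string (derived from the inference-steps
+# slider, e.g. "ddim25") straight to create_diffusion.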
diff --git a/dva/__init__.py b/dva/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d09a529b266d66bd8daa5fc6c6c1c655eb10b83
--- /dev/null
+++ b/dva/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/dva/attr_dict.py b/dva/attr_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..84b85c9d3c5f6065a919fffa80ea36ed99c7e848
--- /dev/null
+++ b/dva/attr_dict.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+
+
+class AttrDict:
+ def __init__(self, entries):
+ self.add_entries_(entries)
+
+ def keys(self):
+ return self.__dict__.keys()
+
+ def values(self):
+ return self.__dict__.values()
+
+ def __getitem__(self, key):
+ return self.__dict__[key]
+
+ def __setitem__(self, key, value):
+ self.__dict__[key] = value
+
+ def __delitem__(self, key):
+ return self.__dict__.__delitem__(key)
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def __repr__(self):
+ return self.__dict__.__repr__()
+
+ def __getattr__(self, attr):
+ if attr.startswith("__"):
+ return self.__getattribute__(attr)
+ return self.__dict__[attr]
+
+ def items(self):
+ return self.__dict__.items()
+
+ def __iter__(self):
+ return iter(self.items())
+
+ def add_entries_(self, entries, overwrite=True):
+ for key, value in entries.items():
+ if key not in self.__dict__:
+ if isinstance(value, dict):
+ self.__dict__[key] = AttrDict(value)
+ else:
+ self.__dict__[key] = value
+ else:
+ if isinstance(value, dict):
+ self.__dict__[key].add_entries_(entries=value, overwrite=overwrite)
+ elif overwrite or self.__dict__[key] is None:
+ self.__dict__[key] = value
+
+ def serialize(self):
+ return json.dumps(self, default=self.obj_to_dict, indent=4)
+
+ def obj_to_dict(self, obj):
+ return obj.__dict__
+
+ def get(self, key, default=None):
+ return self.__dict__.get(key, default)
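+
+
+if __name__ == "__main__":
+ # Minimal usage sketch (illustrative, not part of the training pipeline):
+ # nested dicts become attribute-accessible AttrDicts.
+ cfg = AttrDict({"model": {"lr": 1e-4, "layers": 4}})
+ assert cfg.model.lr == 1e-4 # attribute-style access
+ assert cfg["model"]["layers"] == 4 # dict-style access still works
+ print(cfg.serialize()) # JSON dump via obj_to_dict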
diff --git a/dva/geom.py b/dva/geom.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc9aa6bed0e7e440c41384af2a867ebe4f8d131
--- /dev/null
+++ b/dva/geom.py
@@ -0,0 +1,653 @@
+from typing import Optional
+import numpy as np
+import torch as th
+import torch.nn.functional as F
+import torch.nn as nn
+
+from sklearn.neighbors import KDTree
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+# NOTE: we need pytorch3d primarily for UV rasterization things
+from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
+from pytorch3d.structures import Meshes
+from typing import Union, Optional, Tuple
+import trimesh
+from trimesh import Trimesh
+from trimesh.triangles import points_to_barycentric
+
+try:
+ # pyre-fixme[21]: Could not find module `igl`.
+ from igl import point_mesh_squared_distance # @manual
+
+ # pyre-fixme[3]: Return type must be annotated.
+ # pyre-fixme[2]: Parameter must be annotated.
+ def closest_point(mesh, points):
+ """Helper function that mimics trimesh.proximity.closest_point but uses
+ IGL for faster queries."""
+ v = mesh.vertices
+ vi = mesh.faces
+ dist, face_idxs, p = point_mesh_squared_distance(points, v, vi)
+ return p, dist, face_idxs
+
+except ImportError:
+ from trimesh.proximity import closest_point
+
+
+def closest_point_barycentrics(v, vi, points):
+ """Given a 3D mesh and a set of query points, return closest point barycentrics
+ Args:
+ v: np.array (float)
+ [N, 3] mesh vertices
+
+ vi: np.array (int)
+ [F, 3] mesh triangle indices
+
+ points: np.array (float)
+ [M, 3] query points
+
+ Returns:
+ Tuple[approx, barys, interp_idxs, face_idxs]
+ approx: [M, 3] approximated (closest) points on the mesh
+ barys: [M, 3] barycentric weights that produce "approx"
+ interp_idxs: [M, 3] vertex indices for barycentric interpolation
+ face_idxs: [M] face indices for barycentric interpolation. interp_idxs = vi[face_idxs]
+ """
+ mesh = Trimesh(vertices=v, faces=vi, process=False)
+ p, _, face_idxs = closest_point(mesh, points)
+ p = p.reshape((points.shape[0], 3))
+ face_idxs = face_idxs.reshape((points.shape[0],))
+ barys = points_to_barycentric(mesh.triangles[face_idxs], p)
+ b0, b1, b2 = np.split(barys, 3, axis=1)
+
+ interp_idxs = vi[face_idxs]
+ v0 = v[interp_idxs[:, 0]]
+ v1 = v[interp_idxs[:, 1]]
+ v2 = v[interp_idxs[:, 2]]
+ approx = b0 * v0 + b1 * v1 + b2 * v2
+ return approx, barys, interp_idxs, face_idxs
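+
+# Usage sketch (illustrative values): for a single unit triangle, a query point
+# above its interior projects onto the surface and the barycentric weights sum to 1:
+# v = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
+# vi = np.array([[0, 1, 2]])
+# approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(
+#     v, vi, np.array([[0.25, 0.25, 0.5]]))
+# approx -> [[0.25, 0.25, 0.0]], barys.sum(axis=1) -> [1.0]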
+
+def make_uv_face_index(
+ vt: th.Tensor,
+ vti: th.Tensor,
+ uv_shape: Union[Tuple[int, int], int],
+ flip_uv: bool = True,
+ device: Optional[Union[str, th.device]] = None,
+):
+ """Compute a UV-space face index map identifying which mesh face contains each
+ texel. For texels with no assigned triangle, the index will be -1."""
+
+ if isinstance(uv_shape, int):
+ uv_shape = (uv_shape, uv_shape)
+
+ uv_max_shape_ind = uv_shape.index(max(uv_shape))
+ uv_min_shape_ind = uv_shape.index(min(uv_shape))
+ uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
+
+ if device is not None:
+ if isinstance(device, str):
+ dev = th.device(device)
+ else:
+ dev = device
+ assert dev.type == "cuda"
+ else:
+ dev = th.device("cuda")
+
+ vt = 1.0 - vt.clone()
+
+ if flip_uv:
+ vt = vt.clone()
+ vt[:, 1] = 1 - vt[:, 1]
+ vt_pix = 2.0 * vt.to(dev) - 1.0
+ vt_pix = th.cat([vt_pix, th.ones_like(vt_pix[:, 0:1])], dim=1)
+
+ vt_pix[:, uv_min_shape_ind] *= uv_ratio
+ meshes = Meshes(vt_pix[np.newaxis], vti[np.newaxis].to(dev))
+ with th.no_grad():
+ face_index, _, _, _ = rasterize_meshes(
+ meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0
+ )
+ face_index = face_index[0, ..., 0]
+ return face_index
+
+
+def make_uv_vert_index(
+ vt: th.Tensor,
+ vi: th.Tensor,
+ vti: th.Tensor,
+ uv_shape: Union[Tuple[int, int], int],
+ flip_uv: bool = True,
+):
+ """Compute a UV-space vertex index map identifying which mesh vertices
+ comprise the triangle containing each texel. For texels with no assigned
+ triangle, all indices will be -1.
+ """
+ face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv)
+ vert_index_map = vi[face_index_map.clamp(min=0)]
+ vert_index_map[face_index_map < 0] = -1
+ return vert_index_map.long()
+
+
+def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6):
+ """Computes barycentric coordinates for a set of 2D query points given
+ coordintes for the 3 vertices of the enclosing triangle for each point."""
+ x = points[:, 0] - triangles[2, :, 0]
+ x1 = triangles[0, :, 0] - triangles[2, :, 0]
+ x2 = triangles[1, :, 0] - triangles[2, :, 0]
+ y = points[:, 1] - triangles[2, :, 1]
+ y1 = triangles[0, :, 1] - triangles[2, :, 1]
+ y2 = triangles[1, :, 1] - triangles[2, :, 1]
+ denom = y2 * x1 - y1 * x2
+ n0 = y2 * x - x2 * y
+ n1 = x1 * y - y1 * x
+
+ # Small epsilon to prevent divide-by-zero error.
+ denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))
+
+ bary_0 = n0 / denom
+ bary_1 = n1 / denom
+ bary_2 = 1.0 - bary_0 - bary_1
+
+ return th.stack((bary_0, bary_1, bary_2))
+
+
+def make_uv_barys(
+ vt: th.Tensor,
+ vti: th.Tensor,
+ uv_shape: Union[Tuple[int, int], int],
+ flip_uv: bool = True,
+):
+ """Compute a UV-space barycentric map where each texel contains barycentric
+ coordinates for that texel within its enclosing UV triangle. For texels
+ with no assigned triangle, all 3 barycentric coordinates will be 0.
+ """
+ if isinstance(uv_shape, int):
+ uv_shape = (uv_shape, uv_shape)
+
+ if flip_uv:
+ # Flip here because texture coordinates in some of our topo files are
+ # stored in OpenGL convention with Y=0 on the bottom of the texture
+ # unlike numpy/torch arrays/tensors.
+ vt = vt.clone()
+ vt[:, 1] = 1 - vt[:, 1]
+
+ face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False)
+ vti_map = vti.long()[face_index_map.clamp(min=0)]
+
+ uv_max_shape_ind = uv_shape.index(max(uv_shape))
+ uv_min_shape_ind = uv_shape.index(min(uv_shape))
+ uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
+ vt = vt.clone()
+ vt = vt * 2 - 1
+ vt[:, uv_min_shape_ind] *= uv_ratio
+ uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3)
+
+ uv_grid = th.meshgrid(
+ th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0],
+ th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1],
+ )
+ uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs)
+ uv_grid = uv_grid * 2 - 1
+ uv_grid[..., uv_min_shape_ind] *= uv_ratio
+
+ bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2))
+ bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3)
+ bary_map[face_index_map < 0] = 0
+ return face_index_map, bary_map
+
+
+def index_image_impaint(
+ index_image: th.Tensor,
+ bary_image: Optional[th.Tensor] = None,
+ distance_threshold=100.0,
+):
+ # build a validity mask over texels that have an assigned index
+ if len(index_image.shape) == 3:
+ valid_index = (index_image != -1).any(dim=-1)
+ elif len(index_image.shape) == 2:
+ valid_index = index_image != -1
+ else:
+ raise ValueError("`index_image` should be a [H,W] or [H,W,C] image")
+
+ invalid_index = ~valid_index
+
+ device = index_image.device
+
+ valid_ij = th.stack(th.where(valid_index), dim=-1)
+ invalid_ij = th.stack(th.where(invalid_index), dim=-1)
+ lookup_valid = KDTree(valid_ij.cpu().numpy())
+
+ dists, idxs = lookup_valid.query(invalid_ij.cpu())
+
+ # TODO: try average?
+ idxs = th.as_tensor(idxs, device=device)[..., 0]
+ dists = th.as_tensor(dists, device=device)[..., 0]
+
+ dist_mask = dists < distance_threshold
+
+ invalid_border = th.zeros_like(invalid_index)
+ invalid_border[invalid_index] = dist_mask
+
+ invalid_src_ij = valid_ij[idxs][dist_mask]
+ invalid_dst_ij = invalid_ij[dist_mask]
+
+ index_image_imp = index_image.clone()
+
+ index_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = index_image[
+ invalid_src_ij[:, 0], invalid_src_ij[:, 1]
+ ]
+
+ if bary_image is not None:
+ bary_image_imp = bary_image.clone()
+
+ bary_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = bary_image[
+ invalid_src_ij[:, 0], invalid_src_ij[:, 1]
+ ]
+
+ return index_image_imp, bary_image_imp
+ return index_image_imp
+
+
+class GeometryModule(nn.Module):
+ def __init__(
+ self,
+ v,
+ vi,
+ vt,
+ vti,
+ uv_size,
+ v2uv: Optional[th.Tensor] = None,
+ flip_uv=False,
+ impaint=False,
+ impaint_threshold=100.0,
+ ):
+ super().__init__()
+
+ self.register_buffer("v", th.as_tensor(v))
+ self.register_buffer("vi", th.as_tensor(vi))
+ self.register_buffer("vt", th.as_tensor(vt))
+ self.register_buffer("vti", th.as_tensor(vti))
+ if v2uv is not None:
+ self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64))
+
+ # TODO: should we just pass topology here?
+ # self.n_verts = v2uv.shape[0]
+ self.n_verts = vi.max() + 1
+
+ self.uv_size = uv_size
+
+ # TODO: can't we just index face_index?
+ index_image = make_uv_vert_index(
+ self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
+ ).cpu()
+ face_index, bary_image = make_uv_barys(
+ self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
+ )
+ if impaint:
+ if min(uv_size) >= 1024:
+ logger.info(
+ "impainting index image might take a while for sizes >= 1024"
+ )
+
+ index_image, bary_image = index_image_impaint(
+ index_image, bary_image, impaint_threshold
+ )
+ # TODO: we can avoid doing this 2x
+ face_index = index_image_impaint(
+ face_index, distance_threshold=impaint_threshold
+ )
+
+ self.register_buffer("index_image", index_image.cpu())
+ self.register_buffer("bary_image", bary_image.cpu())
+ self.register_buffer("face_index_image", face_index.cpu())
+
+ def render_index_images(self, uv_size, flip_uv=False, impaint=False):
+ index_image = make_uv_vert_index(
+ self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
+ )
+ face_image, bary_image = make_uv_barys(
+ self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
+ )
+
+ if impaint:
+ index_image, bary_image = index_image_impaint(
+ index_image,
+ bary_image,
+ )
+
+ return index_image, face_image, bary_image
+
+ def vn(self, verts):
+ return vert_normals(verts, self.vi[np.newaxis].to(th.long))
+
+ def to_uv(self, values):
+ return values_to_uv(values, self.index_image, self.bary_image)
+
+ def from_uv(self, values_uv):
+ # TODO: we need to sample this
+ return sample_uv(values_uv, self.vt, self.v2uv.to(th.long))
+
+ def rand_sample_3d_uv(self, count, uv_img):
+ """
+ Sample a set of 3D points on the mesh surface and return the corresponding interpolated values from UV space.
+
+ Args:
+ count - num of 3D points to be sampled
+
+ uv_img - the image in uv space to be sampled, e.g., texture
+ """
+ _mesh = Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.vi.detach().cpu().numpy(), process=False)
+ points, _ = trimesh.sample.sample_surface(_mesh, count)
+ return self.sample_uv_from_3dpts(points, uv_img)
+
+ def sample_uv_from_3dpts(self, points, uv_img):
+ num_pts = points.shape[0]
+ approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points)
+ interp_uv_coords = self.vt[interp_idxs, :] # [N, 3, 2]
+ # do bary interp first to get interp_uv_coord in high-reso uv space
+ target_uv_coords = th.sum(interp_uv_coords * th.from_numpy(barys)[..., None], dim=1).float()
+ # then directly sample from uv space
+ sampled_values = sample_uv(values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords) # [1, count, c]
+ approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2])
+ return approx_values.numpy(), points
+
+ def vert_sample_uv(self, uv_img):
+ count = self.v.shape[0]
+ points = self.v.detach().cpu().numpy()
+ approx_values, _ = self.sample_uv_from_3dpts(points, uv_img)
+ return approx_values
+
+
+def sample_uv(
+ values_uv,
+ uv_coords,
+ v2uv: Optional[th.Tensor] = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+ flip_uvs: bool = False,
+):
+ batch_size = values_uv.shape[0]
+
+ if flip_uvs:
+ uv_coords = uv_coords.clone()
+ uv_coords[:, 1] = 1.0 - uv_coords[:, 1]
+
+ # uv_coords_norm is [1, N, 1, 2] afterwards
+ uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand(
+ batch_size, -1, -1, -1
+ )
+ # uv_shape = values_uv.shape[-2:]
+ # uv_max_shape_ind = uv_shape.index(max(uv_shape))
+ # uv_min_shape_ind = uv_shape.index(min(uv_shape))
+ # uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
+ # uv_coords_norm[..., uv_min_shape_ind] *= uv_ratio
+
+ values = (
+ F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode)
+ .squeeze(-1)
+ .permute((0, 2, 1))
+ )
+
+ if v2uv is not None:
+ values_duplicate = values[:, v2uv]
+ values = values_duplicate.mean(2)
+
+ return values
+
+
+def values_to_uv(values, index_img, bary_img):
+ uv_size = index_img.shape
+ index_mask = th.all(index_img != -1, dim=-1)
+ idxs_flat = index_img[index_mask].to(th.int64)
+ bary_flat = bary_img[index_mask].to(th.float32)
+ # NOTE: here we assume per-vertex values of shape [B, n_verts, C]
+ values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1)
+ values_uv = th.zeros(
+ values.shape[0],
+ values.shape[-1],
+ uv_size[0],
+ uv_size[1],
+ dtype=values.dtype,
+ device=values.device,
+ )
+ values_uv[:, :, index_mask] = values_flat
+ return values_uv
+
+
+def face_normals(v, vi, eps: float = 1e-5):
+ pts = v[:, vi]
+ v0 = pts[:, :, 1] - pts[:, :, 0]
+ v1 = pts[:, :, 2] - pts[:, :, 0]
+ n = th.cross(v0, v1, dim=-1)
+ norm = th.norm(n, dim=-1, keepdim=True)
+ norm[norm < eps] = 1
+ n /= norm
+ return n
+
+
+def vert_normals(v, vi, eps: float = 1.0e-5):
+ fnorms = face_normals(v, vi)
+ fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3)
+ vi_flat = vi.view(1, -1).expand(v.shape[0], -1)
+ vnorms = th.zeros_like(v)
+ for j in range(3):
+ vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j])
+ norm = th.norm(vnorms, dim=-1, keepdim=True)
+ norm[norm < eps] = 1
+ vnorms /= norm
+ return vnorms
+
+
+def compute_view_cos(verts, faces, camera_pos):
+ vn = F.normalize(vert_normals(verts, faces), dim=-1)
+ v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1)
+ return th.einsum("bnd,bnd->bn", vn, v2c)
+
+
+def compute_tbn(geom, vt, vi, vti):
+ """Computes tangent, bitangent, and normal vectors given a mesh.
+ Args:
+ geom: [N, n_verts, 3] th.Tensor
+ Vertex positions.
+ vt: [n_uv_coords, 2] th.Tensor
+ UV coordinates.
+ vi: [..., 3] th.Tensor
+ Face vertex indices.
+ vti: [..., 3] th.Tensor
+ Face UV indices.
+ Returns:
+ [..., 3] th.Tensors for T, B, N.
+ """
+
+ v0 = geom[:, vi[..., 0]]
+ v1 = geom[:, vi[..., 1]]
+ v2 = geom[:, vi[..., 2]]
+ vt0 = vt[vti[..., 0]]
+ vt1 = vt[vti[..., 1]]
+ vt2 = vt[vti[..., 2]]
+
+ v01 = v1 - v0
+ v02 = v2 - v0
+ vt01 = vt1 - vt0
+ vt02 = vt2 - vt0
+ f = 1.0 / (
+ vt01[None, ..., 0] * vt02[None, ..., 1]
+ - vt01[None, ..., 1] * vt02[None, ..., 0]
+ )
+ tangent = f[..., None] * th.stack(
+ [
+ v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1],
+ v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1],
+ v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1],
+ ],
+ dim=-1,
+ )
+ tangent = F.normalize(tangent, dim=-1)
+ normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1)
+ bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1)
+
+ return tangent, bitangent, normal
+
+
+def compute_v2uv(n_verts, vi, vti, n_max=4):
+ """Computes mapping from vertex indices to texture indices.
+
+ Args:
+ vi: [F, 3], triangles
+ vti: [F, 3], texture triangles
+ n_max: int, max number of texture locations
+
+ Returns:
+ [n_verts, n_max], texture indices
+ """
+ v2uv_dict = {}
+ for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)):
+ v2uv_dict.setdefault(i_v, set()).add(i_uv)
+ assert len(v2uv_dict) == n_verts
+ v2uv = np.zeros((n_verts, n_max), dtype=np.int32)
+ for i in range(n_verts):
+ vals = sorted(list(v2uv_dict[i]))
+ # pad the row with the first uv index, then overwrite the prefix with all of them
+ v2uv[i, :] = vals[0]
+ v2uv[i, : len(vals)] = np.array(vals)
+ return v2uv
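+
+# Usage sketch (illustrative): a seam vertex shared by two UV islands gets a row
+# listing its distinct texture indices, padded by repeating the first one,
+# e.g. vals = [7, 12] with n_max=4 yields the row [7, 12, 7, 7].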
+
+
+def compute_neighbours(n_verts, vi, n_max_values=10):
+ """Computes first-ring neighbours given vertices and faces."""
+ n_vi = vi.shape[0]
+
+ adj = {i: set() for i in range(n_verts)}
+ for i in range(n_vi):
+ for idx in vi[i]:
+ adj[idx] |= set(vi[i]) - set([idx])
+
+ nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values))
+ nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32)
+
+ for idx in range(n_verts):
+ n_values = min(len(adj[idx]), n_max_values)
+ nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values]
+ nbs_weights[idx, :n_values] = -1.0 / n_values
+
+ return nbs_idxs, nbs_weights
+
+
+def make_postex(v, idxim, barim):
+ return (
+ barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]]
+ + barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]]
+ + barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]]
+ ).permute(0, 3, 1, 2)
+
+
+def matrix_to_axisangle(r):
+ # use a local name distinct from the `th` module alias to avoid shadowing it
+ theta = th.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0))[..., None]
+ vec = (
+ 0.5
+ * th.stack(
+ [
+ r[..., 2, 1] - r[..., 1, 2],
+ r[..., 0, 2] - r[..., 2, 0],
+ r[..., 1, 0] - r[..., 0, 1],
+ ],
+ dim=-1,
+ )
+ / th.sin(theta)
+ )
+ return theta, vec
+
+
+def axisangle_to_matrix(rvec):
+ theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1))
+ rvec = rvec / theta[..., None]
+ costh = th.cos(theta)
+ sinth = th.sin(theta)
+ return th.stack(
+ (
+ th.stack(
+ (
+ rvec[..., 0] ** 2 + (1.0 - rvec[..., 0] ** 2) * costh,
+ rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth,
+ rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth,
+ ),
+ dim=-1,
+ ),
+ th.stack(
+ (
+ rvec[..., 0] * rvec[..., 1] * (1.0 - costh) + rvec[..., 2] * sinth,
+ rvec[..., 1] ** 2 + (1.0 - rvec[..., 1] ** 2) * costh,
+ rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth,
+ ),
+ dim=-1,
+ ),
+ th.stack(
+ (
+ rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth,
+ rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth,
+ rvec[..., 2] ** 2 + (1.0 - rvec[..., 2] ** 2) * costh,
+ ),
+ dim=-1,
+ ),
+ ),
+ dim=-2,
+ )
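+
+# Round-trip sketch (illustrative): a rotation of pi/2 about +Z converts to its
+# 3x3 matrix and back to (angle, axis):
+# rvec = th.tensor([[0.0, 0.0, np.pi / 2]])
+# R = axisangle_to_matrix(rvec)         # [1, 3, 3], equals Rz(pi/2)
+# theta, axis = matrix_to_axisangle(R)  # theta ~ pi/2, axis ~ [0, 0, 1]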
+
+
+def rotation_interp(r0, r1, alpha):
+ r0a = r0.view(-1, 3, 3)
+ r1a = r1.view(-1, 3, 3)
+ r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)
+
+ theta, rvec = matrix_to_axisangle(r)
+ rvec = rvec * (alpha * theta)
+
+ r = axisangle_to_matrix(rvec)
+ return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0)
+
+
+def convert_camera_parameters(Rt, K):
+ R = Rt[:, :3, :3]
+ t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2)
+ return dict(
+ campos=t,
+ camrot=R,
+ focal=K[:, :2, :2],
+ princpt=K[:, :2, 2],
+ )
+
+
+def project_points_multi(p, Rt, K, normalize=False, size=None):
+ """Project a set of 3D points into multiple cameras with a pinhole model.
+ Args:
+ p: [B, N, 3], input 3D points in world coordinates
+ Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to)
+ K: [B, NC, 3, 3], intrinsics
+ normalize: bool, whether to normalize coordinates to [-1.0, 1.0]
+ Returns:
+ tuple:
+ - [B, NC, N, 2] - projected points
+ - [B, NC, N] - their depth values
+ """
+ B, N = p.shape[:2]
+ NC = Rt.shape[1]
+
+ Rt = Rt.reshape(B * NC, 3, 4)
+ K = K.reshape(B * NC, 3, 3)
+
+ # [B, N, 3] -> [B * NC, N, 3]
+ p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3)
+ p_cam = p @ Rt[:, :3, :3].transpose(-2, -1) + Rt[:, :3, 3][:, np.newaxis]
+ p_pix = p_cam @ K.transpose(-2, -1)
+ p_depth = p_pix[:, :, 2:]
+ p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2)
+ p_depth = p_depth.reshape(B, NC, N)
+
+ if normalize:
+ assert size is not None
+ h, w = size
+ p_pix = (
+ 2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0
+ )
+ return p_pix, p_depth
diff --git a/dva/io.py b/dva/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2875b51a2bfdbd2c00f904b6b732583fcd0b182
--- /dev/null
+++ b/dva/io.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import cv2
+import numpy as np
+import copy
+import importlib
+from typing import Any, Dict
+
+def load_module(module_name, class_name=None, silent: bool = False):
+ module = importlib.import_module(module_name)
+ return getattr(module, class_name) if class_name else module
+
+
+def load_class(class_name):
+ return load_module(*class_name.rsplit(".", 1))
+
+
+def load_from_config(config, **kwargs):
+ """Instantiate an object given a config and arguments."""
+ assert "class_name" in config and "module_name" not in config
+ config = copy.deepcopy(config)
+ class_name = config.pop("class_name")
+ object_class = load_class(class_name)
+ return object_class(**config, **kwargs)
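+
+# Usage sketch (illustrative class path; any importable class whose keyword
+# arguments match the remaining config keys works the same way):
+# cfg = {"class_name": "torch.nn.Linear", "in_features": 68, "out_features": 32}
+# layer = load_from_config(cfg)  # -> torch.nn.Linear(68, 32)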
+
+
+def load_opencv_calib(extrin_path, intrin_path):
+ cameras = {}
+
+ fse = cv2.FileStorage()
+ fse.open(extrin_path, cv2.FileStorage_READ)
+
+ fsi = cv2.FileStorage()
+ fsi.open(intrin_path, cv2.FileStorage_READ)
+
+ names = [
+ fse.getNode("names").at(c).string() for c in range(fse.getNode("names").size())
+ ]
+
+ for camera in names:
+ rot = fse.getNode(f"R_{camera}").mat()
+ R = fse.getNode(f"Rot_{camera}").mat()
+ T = fse.getNode(f"T_{camera}").mat()
+ R_pred = cv2.Rodrigues(rot)[0]
+ assert np.all(np.isclose(R_pred, R))
+ K = fsi.getNode(f"K_{camera}").mat()
+ cameras[camera] = {
+ "Rt": np.concatenate([R, T], axis=1).astype(np.float32),
+ "K": K.astype(np.float32),
+ }
+ return cameras
\ No newline at end of file
diff --git a/dva/layers.py b/dva/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d219ecfed8c5e8ad2dfb1f7a1204bffa2ea2344
--- /dev/null
+++ b/dva/layers.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch as th
+import torch.nn as nn
+
+import numpy as np
+
+from dva.mvp.models.utils import Conv2dWN, Conv2dWNUB, ConvTranspose2dWNUB, initmod
+
+
+class ConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ size,
+ lrelu_slope=0.2,
+ kernel_size=3,
+ padding=1,
+ wnorm_dim=0,
+ ):
+ super().__init__()
+
+ self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1)
+ self.conv1 = Conv2dWNUB(
+ in_channels,
+ in_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ height=size,
+ width=size,
+ )
+
+ self.lrelu1 = nn.LeakyReLU(lrelu_slope)
+ self.conv2 = Conv2dWNUB(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ height=size,
+ width=size,
+ )
+ self.lrelu2 = nn.LeakyReLU(lrelu_slope)
+
+ def forward(self, x):
+ x_skip = self.conv_resize(x)
+ x = self.conv1(x)
+ x = self.lrelu1(x)
+ x = self.conv2(x)
+ x = self.lrelu2(x)
+ return x + x_skip
+
+
+def tile2d(x, size: int):
+ """Tile a given set of features into a convolutional map.
+
+ Args:
+ x: float tensor of shape [N, F]
+ size: int, output spatial resolution
+
+ Returns:
+ a feature map [N, F, size, size]
+ """
+ # NOTE: only int sizes are supported; tuple handling is not implemented
+ return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size)
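+
+# e.g. tile2d(th.randn(4, 16), 8) -> a [4, 16, 8, 8] map that repeats each
+# feature vector over the 8x8 spatial grid (illustrative shapes).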
+
+
+def weights_initializer(m, alpha: float = 1.0):
+ return initmod(m, nn.init.calculate_gain("leaky_relu", alpha))
+
+
+class UNetWB(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ size,
+ n_init_ftrs=8,
+ out_scale=0.1,
+ ):
+ super().__init__()
+
+ self.out_scale = out_scale
+
+ F = n_init_ftrs
+
+ # TODO: allow changing the size?
+ self.size = size
+
+ self.down1 = nn.Sequential(
+ Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.down2 = nn.Sequential(
+ Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.down3 = nn.Sequential(
+ Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.down4 = nn.Sequential(
+ Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.down5 = nn.Sequential(
+ Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.up1 = nn.Sequential(
+ ConvTranspose2dWNUB(
+ 16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
+ ),
+ nn.LeakyReLU(0.2),
+ )
+ self.up2 = nn.Sequential(
+ ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.up3 = nn.Sequential(
+ ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.up4 = nn.Sequential(
+ ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1),
+ nn.LeakyReLU(0.2),
+ )
+ self.up5 = nn.Sequential(
+ ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2)
+ )
+ self.out = Conv2dWNUB(
+ F + in_channels, out_channels, self.size, self.size, kernel_size=1
+ )
+ self.apply(lambda x: initmod(x, 0.2))
+ initmod(self.out, 1.0)
+
+ def forward(self, x):
+ x1 = x
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ x6 = self.down5(x5)
+ # TODO: switch to concat?
+ x = self.up1(x6) + x5
+ x = self.up2(x) + x4
+ x = self.up3(x) + x3
+ x = self.up4(x) + x2
+ x = self.up5(x)
+ x = th.cat([x, x1], dim=1)
+ return self.out(x) * self.out_scale
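+
+
+if __name__ == "__main__":
+ # Shape smoke test (illustrative sizes): the U-Net is resolution-locked to
+ # `size` by its untied biases, so input and output share that resolution.
+ unet = UNetWB(in_channels=3, out_channels=3, size=256)
+ out = unet(th.randn(1, 3, 256, 256))
+ assert out.shape == (1, 3, 256, 256)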
diff --git a/dva/losses.py b/dva/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..850475114bbbf06f552646b548d6f926813c0098
--- /dev/null
+++ b/dva/losses.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+import torch as th
+import numpy as np
+
+import logging
+
+from .vgg import VGGLossMasked
+
+logger = logging.getLogger("dva.{__name__}")
+
+class DCTLoss(nn.Module):
+ def __init__(self, weights):
+ super().__init__()
+ self.weights = weights
+
+ def forward(self, inputs, preds, iteration=None):
+ loss_dict = {"loss_total": 0.0}
+ target = inputs['gt']
+ recon = preds['recon']
+ posterior = preds['posterior']
+ fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1)))
+ fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1)))
+ loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon))
+ loss_recon_l1 = th.mean(th.abs(target - recon))
+ loss_kl = posterior.kl().mean()
+ loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl)
+ loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl
+
+ loss_dict["loss_total"] = loss_total
+ return loss_total, loss_dict
+
+class VAESepL2Loss(nn.Module):
+ def __init__(self, weights):
+ super().__init__()
+ self.weights = weights
+
+ def forward(self, inputs, preds, iteration=None):
+ loss_dict = {"loss_total": 0.0}
+ target = inputs['gt']
+ recon = preds['recon']
+ posterior = preds['posterior']
+ recon_diff = (target - recon) ** 2
+ loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
+ loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
+ loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
+ loss_kl = posterior.kl().mean()
+ loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
+ loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
+ if "kl" in self.weights:
+ loss_total += self.weights.kl * loss_kl
+
+ loss_dict["loss_total"] = loss_total
+ return loss_total, loss_dict
+
+class VAESepLoss(nn.Module):
+ def __init__(self, weights):
+ super().__init__()
+ self.weights = weights
+
+ def forward(self, inputs, preds, iteration=None):
+ loss_dict = {"loss_total": 0.0}
+ target = inputs['gt']
+ recon = preds['recon']
+ posterior = preds['posterior']
+ recon_diff = th.abs(target - recon)
+ loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
+ loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
+ loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
+ loss_kl = posterior.kl().mean()
+ loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
+ loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
+ if "kl" in self.weights:
+ loss_total += self.weights.kl * loss_kl
+
+ loss_dict["loss_total"] = loss_total
+ return loss_total, loss_dict
+
+class VAELoss(nn.Module):
+ def __init__(self, weights):
+ super().__init__()
+ self.weights = weights
+
+ def forward(self, inputs, preds, iteration=None):
+ loss_dict = {"loss_total": 0.0}
+ target = inputs['gt']
+ recon = preds['recon']
+ posterior = preds['posterior']
+ loss_recon_l1 = th.mean(th.abs(target - recon))
+ loss_kl = posterior.kl().mean()
+ loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl)
+ loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl
+
+ loss_dict["loss_total"] = loss_total
+ return loss_total, loss_dict
+
+class PrimSDFLoss(nn.Module):
+ def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000):
+ super().__init__()
+ self.weights = weights
+ self.shape_opt_steps = shape_opt_steps
+ self.tex_opt_steps = tex_opt_steps
+
+ def forward(self, inputs, preds, iteration=None):
+ loss_dict = {"loss_total": 0.0}
+
+ if iteration < self.shape_opt_steps:
+ target_sdf = inputs['sdf']
+ sdf = preds['sdf']
+ loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf))
+ loss_dict.update(loss_sdf_l1=loss_sdf_l1)
+ loss_total = self.weights.sdf_l1 * loss_sdf_l1
+
+ prim_scale = preds["prim_scale"]
+ # we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube
+ if "vol_sum" in self.weights:
+ loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1))
+ loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
+ loss_total += self.weights.vol_sum * loss_prim_vol_sum
+
+ if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps:
+ target_tex = inputs['tex']
+ tex = preds['tex']
+ loss_tex_l1 = th.mean(th.abs(tex - target_tex))
+ loss_dict.update(loss_tex_l1=loss_tex_l1)
+
+ loss_total = (
+ self.weights.rgb_l1 * loss_tex_l1
+ )
+ if "mat_l1" in self.weights:
+ target_mat = inputs['mat']
+ mat = preds['mat']
+ loss_mat_l1 = th.mean(th.abs(mat - target_mat))
+ loss_dict.update(loss_mat_l1=loss_mat_l1)
+ loss_total += self.weights.mat_l1 * loss_mat_l1
+
+ if "grad_l2" in self.weights:
+ loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2)
+ loss_total += self.weights.grad_l2 * loss_grad_l2
+ loss_dict.update(loss_grad_l2=loss_grad_l2)
+
+ loss_dict["loss_total"] = loss_total
+ return loss_total, loss_dict
+
+
+class TotalMVPLoss(nn.Module):
+ def __init__(self, weights, assets=None):
+ super().__init__()
+
+ self.weights = weights
+
+ if "vgg" in self.weights:
+ self.vgg_loss = VGGLossMasked()
+
+ def forward(self, inputs, preds, iteration=None):
+
+ loss_dict = {"loss_total": 0.0}
+
+ B = inputs["image"].shape
+
+ # rgb
+ target_rgb = inputs["image"].permute(0, 2, 3, 1)
+ # removing the mask
+ target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis]
+
+ rgb = preds["rgb"]
+ loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0)
+ loss_dict.update(loss_rgb_mse=loss_rgb_mse)
+
+ alpha = preds["alpha"]
+
+ # mask loss
+ target_mask = inputs["image_mask"][:, 0].to(th.float32)
+ loss_mask_mae = th.mean((target_mask - alpha).abs())
+ loss_dict.update(loss_mask_mae=loss_mask_mae)
+
+ B = alpha.shape[0]
+
+ # beta prior on opacity
+ loss_alpha_prior = th.mean(
+ th.log(0.1 + alpha.reshape(B, -1))
+ + th.log(0.1 + 1.0 - alpha.reshape(B, -1))
+ - -2.20727
+ )
+ loss_dict.update(loss_alpha_prior=loss_alpha_prior)
+
+ prim_scale = preds["prim_scale"]
+ loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1))
+ loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
+
+ loss_total = (
+ self.weights.rgb_mse * loss_rgb_mse
+ + self.weights.mask_mae * loss_mask_mae
+ + self.weights.alpha_prior * loss_alpha_prior
+ + self.weights.prim_vol_sum * loss_prim_vol_sum
+ )
+
+ if "embs_l2" in self.weights:
+ loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1))
+ loss_total += self.weights.embs_l2 * loss_embs_l2
+ loss_dict.update(loss_embs_l2=loss_embs_l2)
+
+ if "vgg" in self.weights:
+ loss_vgg = self.vgg_loss(
+ rgb.permute(0, 3, 1, 2),
+ target_rgb.permute(0, 3, 1, 2),
+ inputs["image_mask"],
+ )
+ loss_total += self.weights.vgg * loss_vgg
+ loss_dict.update(loss_vgg=loss_vgg)
+
+ if "prim_scale_var" in self.weights:
+ log_prim_scale = th.log(prim_scale)
+ # NOTE: should we detach this?
+ log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True)
+ loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0)
+ loss_total += self.weights.prim_scale_var * loss_prim_scale_var
+ loss_dict.update(loss_prim_scale_var=loss_prim_scale_var)
+
+ loss_dict["loss_total"] = loss_total
+
+ return loss_total, loss_dict
+
+
+def process_losses(loss_dict, reduce=True, detach=True):
+ """Preprocess the dict of losses outputs."""
+ result = {
+ k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")
+ }
+ if detach:
+ result = {k: v.detach() for k, v in result.items()}
+ if reduce:
+ result = {k: float(v.mean().item()) for k, v in result.items()}
+ return result
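+
+# Usage sketch (illustrative names): flatten a loss dict for logging, stripping
+# the "loss_" prefix and reducing tensors to python floats:
+# _, loss_dict = criterion(inputs, preds, iteration=step)
+# process_losses(loss_dict)  # -> {"total": 1.23, "kl": 0.04, ...}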
diff --git a/dva/mvp/extensions/mvpraymarch/bvh.cu b/dva/mvp/extensions/mvpraymarch/bvh.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d203b40eddb12631d92c41d6051824a43da11236
--- /dev/null
+++ b/dva/mvp/extensions/mvpraymarch/bvh.cu
@@ -0,0 +1,292 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include
+#include
+#include
+#include