Delete lrm
- lrm/__init__.py +0 -5
- lrm/cam_utils.py +0 -138
- lrm/inferrer.py +0 -232
- lrm/models/__init__.py +0 -5
- lrm/models/encoders/__init__.py +0 -5
- lrm/models/encoders/dino_wrapper2.py +0 -51
- lrm/models/generator.py +0 -87
- lrm/models/rendering/__init__.py +0 -5
- lrm/models/rendering/synthesizer_part.py +0 -194
- lrm/models/rendering/utils/__init__.py +0 -14
- lrm/models/rendering/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- lrm/models/rendering/utils/__pycache__/math_utils.cpython-310.pyc +0 -0
- lrm/models/rendering/utils/__pycache__/ray_marcher.cpython-310.pyc +0 -0
- lrm/models/rendering/utils/__pycache__/ray_sampler_part.cpython-310.pyc +0 -0
- lrm/models/rendering/utils/__pycache__/renderer.cpython-310.pyc +0 -0
- lrm/models/rendering/utils/math_utils.py +0 -123
- lrm/models/rendering/utils/ray_marcher.py +0 -73
- lrm/models/rendering/utils/ray_sampler_part.py +0 -94
- lrm/models/rendering/utils/renderer.py +0 -314
- lrm/models/transformer.py +0 -135
lrm/__init__.py
DELETED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
lrm/cam_utils.py
DELETED
@@ -1,138 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch
-import numpy as np
-import math
-
-"""
-R: (N, 3, 3)
-T: (N, 3)
-E: (N, 4, 4)
-vector: (N, 3)
-"""
-
-
-def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
-    """
-    Compose the standard form extrinsic matrix from R and T.
-    Batched I/O.
-    """
-    RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
-    return compose_extrinsic_RT(RT)
-
-
-def compose_extrinsic_RT(RT: torch.Tensor):
-    """
-    Compose the standard form extrinsic matrix from RT.
-    Batched I/O.
-    """
-    return torch.cat([
-        RT,
-        torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(RT.shape[0], 1, 1).to(RT.device)
-    ], dim=1)
-
-
-def decompose_extrinsic_R_T(E: torch.Tensor):
-    """
-    Decompose the standard extrinsic matrix into R and T.
-    Batched I/O.
-    """
-    RT = decompose_extrinsic_RT(E)
-    return RT[:, :, :3], RT[:, :, 3]
-
-
-def decompose_extrinsic_RT(E: torch.Tensor):
-    """
-    Decompose the standard extrinsic matrix into RT.
-    Batched I/O.
-    """
-    return E[:, :3, :]
-
-
-def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
-    """
-    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
-    Return batched fx, fy, cx, cy
-    """
-    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
-    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
-    width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
-    fx, fy = fx / width, fy / height
-    cx, cy = cx / width, cy / height
-    return fx, fy, cx, cy
-
-
-def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
-    """
-    RT: (N, 3, 4)
-    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
-    """
-    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
-    return torch.cat([
-        RT.reshape(-1, 12),
-        fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
-    ], dim=-1)
-
-
-def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
-    """
-    RT: (N, 3, 4)
-    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
-    """
-    E = compose_extrinsic_RT(RT)
-    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
-    I = torch.stack([
-        torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
-        torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
-        torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
-    ], dim=1)
-    return torch.cat([
-        E.reshape(-1, 16),
-        I.reshape(-1, 9),
-    ], dim=-1)
-
-
-def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
-    """
-    camera_position: (M, 3)
-    look_at: (3)
-    up_world: (3)
-    return: (M, 3, 4)
-    """
-    # by default, looking at the origin and world up is pos-z
-    if look_at is None:
-        look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
-    if up_world is None:
-        up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
-    look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
-    up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
-
-    z_axis = camera_position - look_at
-    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
-    x_axis = torch.cross(up_world, z_axis)
-    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
-    y_axis = torch.cross(z_axis, x_axis)
-    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
-    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
-    return extrinsics
-
-def get_surrounding_views(M, radius, elevation):
-    # convert spherical coordinates (radius, azimuth, elevation) to Cartesian coordinates (x, y, z).
-    camera_positions = []
-    rand_theta = np.random.uniform(0, np.pi/180)
-    elevation = math.radians(elevation)
-    for i in range(M):
-        theta = 2 * math.pi * i / M + rand_theta
-        x = radius * math.cos(theta) * math.cos(elevation)
-        y = radius * math.sin(theta) * math.cos(elevation)
-        z = radius * math.sin(elevation)
-        camera_positions.append([x, y, z])
-    camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
-    extrinsics = center_looking_at_camera_pose(camera_positions)
-
-    return extrinsics
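For context on the deleted camera utilities: the 16-dim "principle" camera vector they produce is just the flattened 3x4 extrinsic plus the normalized intrinsics. A minimal standalone sketch (torch only, with assumed toy values rather than the module's defaults) looks like this:

import torch

# toy 3x4 extrinsic (camera 1.5 units out, looking at the origin) - assumed values
RT = torch.tensor([[[0., 0., 1., 1.5],
                    [1., 0., 0., 0.0],
                    [0., 1., 0., 0.0]]])          # (N=1, 3, 4)
# intrinsics in the [[fx, fy], [cx, cy], [w, h]] layout used above
intrinsics = torch.tensor([[[384., 384.],
                            [256., 256.],
                            [512., 512.]]])       # (N=1, 3, 2)

fx, fy = intrinsics[:, 0, 0] / intrinsics[:, 2, 0], intrinsics[:, 0, 1] / intrinsics[:, 2, 1]
cx, cy = intrinsics[:, 1, 0] / intrinsics[:, 2, 0], intrinsics[:, 1, 1] / intrinsics[:, 2, 1]

camera = torch.cat([RT.reshape(-1, 12),
                    fx[:, None], fy[:, None], cx[:, None], cy[:, None]], dim=-1)
print(camera.shape)  # torch.Size([1, 16]), matching CameraEmbedder(raw_dim=12+4)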
lrm/inferrer.py
DELETED
@@ -1,232 +0,0 @@
-import torch
-import math
-import os
-import imageio
-import mcubes
-import trimesh
-import numpy as np
-import argparse
-from torchvision.utils import save_image
-from PIL import Image
-import glob
-from .models.generator import LRMGenerator  # Make sure this import is correct
-from .cam_utils import build_camera_principle, build_camera_standard, center_looking_at_camera_pose  # Make sure this import is correct
-from functools import partial
-from rembg import remove, new_session
-from kiui.op import recenter
-import kiui
-
-class LRMInferrer:
-    def __init__(self, model_name: str, resume: str):
-        print("Initializing LRMInferrer")
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        _model_kwargs = {'camera_embed_dim': 1024, 'rendering_samples_per_ray': 128, 'transformer_dim': 1024, 'transformer_layers': 16, 'transformer_heads': 16, 'triplane_low_res': 32, 'triplane_high_res': 64, 'triplane_dim': 80, 'encoder_freeze': False}
-
-        self.model = self._build_model(_model_kwargs).eval().to(self.device)
-        checkpoint = torch.load(resume, map_location='cpu')
-        state_dict = checkpoint['model_state_dict']
-        self.model.load_state_dict(state_dict)
-        del checkpoint, state_dict
-        torch.cuda.empty_cache()
-
-    def __enter__(self):
-        print("Entering context")
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        print("Exiting context")
-        if exc_type:
-            print(f"Exception type: {exc_type}")
-            print(f"Exception value: {exc_val}")
-            print(f"Traceback: {exc_tb}")
-
-    def _build_model(self, model_kwargs):
-        print("Building model")
-        model = LRMGenerator(**model_kwargs).to(self.device)
-        print("Loaded model from checkpoint")
-        return model
-
-    @staticmethod
-    def get_surrounding_views(M, radius, elevation):
-        camera_positions = []
-        rand_theta = np.random.uniform(0, np.pi/180)
-        elevation = math.radians(elevation)
-        for i in range(M):
-            theta = 2 * math.pi * i / M + rand_theta
-            x = radius * math.cos(theta) * math.cos(elevation)
-            y = radius * math.sin(theta) * math.cos(elevation)
-            z = radius * math.sin(elevation)
-            camera_positions.append([x, y, z])
-        camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
-        extrinsics = center_looking_at_camera_pose(camera_positions)
-        return extrinsics
-
-    @staticmethod
-    def _default_intrinsics():
-        fx = fy = 384
-        cx = cy = 256
-        w = h = 512
-        intrinsics = torch.tensor([
-            [fx, fy],
-            [cx, cy],
-            [w, h],
-        ], dtype=torch.float32)
-        return intrinsics
-
-    def _default_source_camera(self, batch_size: int = 1):
-        dist_to_center = 1.5
-        canonical_camera_extrinsics = torch.tensor([[
-            [0, 0, 1, 1],
-            [1, 0, 0, 0],
-            [0, 1, 0, 0],
-        ]], dtype=torch.float32)
-        canonical_camera_intrinsics = self._default_intrinsics().unsqueeze(0)
-        source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
-        return source_camera.repeat(batch_size, 1)
-
-    def _default_render_cameras(self, batch_size: int = 1):
-        render_camera_extrinsics = self.get_surrounding_views(160, 1.5, 0)
-        render_camera_intrinsics = self._default_intrinsics().unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
-        render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
-        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
-
-    @staticmethod
-    def images_to_video(images, output_path, fps, verbose=False):
-        os.makedirs(os.path.dirname(output_path), exist_ok=True)
-        frames = []
-        for i in range(images.shape[0]):
-            frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
-            assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
-                f"Frame shape mismatch: {frame.shape} vs {images.shape}"
-            assert frame.min() >= 0 and frame.max() <= 255, \
-                f"Frame value out of range: {frame.min()} ~ {frame.max()}"
-            frames.append(frame)
-        imageio.mimwrite(output_path, np.stack(frames), fps=fps)
-        if verbose:
-            print(f"Saved video to {output_path}")
-
-    def infer_single(self, image: torch.Tensor, render_size: int, mesh_size: int, export_video: bool, export_mesh: bool):
-        print("infer_single called")
-        mesh_thres = 1.0
-        chunk_size = 2
-        batch_size = 1
-
-        source_camera = self._default_source_camera(batch_size).to(self.device)
-        render_cameras = self._default_render_cameras(batch_size).to(self.device)
-
-        with torch.no_grad():
-            planes = self.model.forward(image, source_camera)
-            results = {}
-
-            if export_video:
-                print("Starting export_video")
-                frames = []
-                for i in range(0, render_cameras.shape[1], chunk_size):
-                    print(f"Processing chunk {i} to {i + chunk_size}")
-                    frames.append(
-                        self.model.synthesizer(
-                            planes,
-                            render_cameras[:, i:i+chunk_size],
-                            render_size,
-                            render_size,
-                            0,
-                            0
-                        )
-                    )
-                frames = {
-                    k: torch.cat([r[k] for r in frames], dim=1)
-                    for k in frames[0].keys()
-                }
-                results.update({
-                    'frames': frames,
-                })
-                print("Finished export_video")
-
-            if export_mesh:
-                print("Starting export_mesh")
-                grid_out = self.model.synthesizer.forward_grid(
-                    planes=planes,
-                    grid_size=mesh_size,
-                )
-                vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
-                vtx = vtx / (mesh_size - 1) * 2 - 1
-                vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
-                vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
-                vtx_colors = (vtx_colors * 255).astype(np.uint8)
-                mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
-                results.update({
-                    'mesh': mesh,
-                })
-                print("Finished export_mesh")
-
-        return results
-
-    def infer(self, source_image: str, dump_path: str, source_size: int, render_size: int, mesh_size: int, export_video: bool, export_mesh: bool):
-        print("infer called")
-        session = new_session("isnet-general-use")
-        rembg_remove = partial(remove, session=session)
-        image_name = os.path.basename(source_image)
-        uid = image_name.split('.')[0]
-
-        image = kiui.read_image(source_image, mode='uint8')
-        image = rembg_remove(image)
-        mask = rembg_remove(image, only_mask=True)
-        image = recenter(image, mask, border_ratio=0.20)
-        os.makedirs(dump_path, exist_ok=True)
-
-        image = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.0
-        if image.shape[1] == 4:
-            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
-        image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
-        image = torch.clamp(image, 0, 1)
-        save_image(image, os.path.join(dump_path, f'{uid}.png'))
-
-        results = self.infer_single(
-            image.cuda(),
-            render_size=render_size,
-            mesh_size=mesh_size,
-            export_video=export_video,
-            export_mesh=export_mesh,
-        )
-
-        if 'frames' in results:
-            renderings = results['frames']
-            for k, v in renderings.items():
-                if k == 'images_rgb':
-                    self.images_to_video(
-                        v[0],
-                        os.path.join(dump_path, f'{uid}.mp4'),
-                        fps=40,
-                    )
-            print(f"Export video success to {dump_path}")
-
-        if 'mesh' in results:
-            mesh = results['mesh']
-            mesh.export(os.path.join(dump_path, f'{uid}.obj'), 'obj')
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--model_name', type=str, default='lrm-base-obj-v1')
-    parser.add_argument('--source_path', type=str, default='./assets/cat.png')
-    parser.add_argument('--dump_path', type=str, default='./results/single_image')
-    parser.add_argument('--source_size', type=int, default=512)
-    parser.add_argument('--render_size', type=int, default=384)
-    parser.add_argument('--mesh_size', type=int, default=512)
-    parser.add_argument('--export_video', action='store_true')
-    parser.add_argument('--export_mesh', action='store_true')
-    parser.add_argument('--resume', type=str, required=True, help='Path to a checkpoint to resume training from')
-    args = parser.parse_args()
-
-    with LRMInferrer(model_name=args.model_name, resume=args.resume) as inferrer:
-        with torch.autocast(device_type="cuda", cache_enabled=False, dtype=torch.float32):
-            print("Start inference for image:", args.source_path)
-            inferrer.infer(
-                source_image=args.source_path,
-                dump_path=args.dump_path,
-                source_size=args.source_size,
-                render_size=args.render_size,
-                mesh_size=args.mesh_size,
-                export_video=args.export_video,
-                export_mesh=args.export_mesh,
-            )
-            print("Finished inference for image:", args.source_path)
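For reference, the deleted entry point was driven from Python roughly as below (mirroring its own __main__ block); the checkpoint path is a placeholder, and this import no longer resolves after this commit:

import torch
from lrm.inferrer import LRMInferrer  # removed by this commit

with LRMInferrer(model_name='lrm-base-obj-v1', resume='./checkpoints/model.pth') as inferrer:
    with torch.autocast(device_type="cuda", cache_enabled=False, dtype=torch.float32):
        inferrer.infer(
            source_image='./assets/cat.png',
            dump_path='./results/single_image',
            source_size=512,
            render_size=384,
            mesh_size=512,
            export_video=True,
            export_mesh=True,
        )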
lrm/models/__init__.py
DELETED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
lrm/models/encoders/__init__.py
DELETED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
lrm/models/encoders/dino_wrapper2.py
DELETED
@@ -1,51 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch.nn as nn
-from transformers import ViTImageProcessor, ViTModel, AutoImageProcessor, AutoModel, Dinov2Model
-
-class DinoWrapper(nn.Module):
-    """
-    Dino v1 wrapper using huggingface transformer implementation.
-    """
-    def __init__(self, model_name: str, freeze: bool = True):
-        super().__init__()
-        self.model, self.processor = self._build_dino(model_name)
-        if freeze:
-            self._freeze()
-
-    def forward(self, image):
-        # image: [N, C, H, W], on cpu
-        # RGB image with [0,1] scale and properly sized
-        inputs = self.processor(images=image.float(), return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
-        # This resampling of positional embedding uses bicubic interpolation
-        outputs = self.model(**inputs)
-        last_hidden_states = outputs.last_hidden_state
-        return last_hidden_states
-
-    def _freeze(self):
-        print(f"======== Freezing DinoWrapper ========")
-        self.model.eval()
-        for name, param in self.model.named_parameters():
-            param.requires_grad = False
-
-    @staticmethod
-    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
-        import requests
-        try:
-            processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
-            processor.do_center_crop = False
-            model = AutoModel.from_pretrained('facebook/dinov2-base')
-            return model, processor
-        except requests.exceptions.ProxyError as err:
-            if proxy_error_retries > 0:
-                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
-                import time
-                time.sleep(proxy_error_cooldown)
-                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
-            else:
-                raise err
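The wrapper above boils down to running the Hugging Face DINOv2 backbone on a [0, 1] image tensor and taking its last hidden state. A minimal sketch of that call, with an arbitrary placeholder resolution, is:

import torch
from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
processor.do_center_crop = False
model = AutoModel.from_pretrained('facebook/dinov2-base').eval()

image = torch.rand(1, 3, 448, 448)  # RGB in [0, 1], already resized; placeholder input
inputs = processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False)
with torch.no_grad():
    tokens = model(**inputs).last_hidden_state  # (1, 1 + num_patches, 768)
print(tokens.shape)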
lrm/models/generator.py
DELETED
@@ -1,87 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch.nn as nn
-
-from .encoders.dino_wrapper2 import DinoWrapper
-from .transformer import TriplaneTransformer
-from .rendering.synthesizer_part import TriplaneSynthesizer
-
-
-class CameraEmbedder(nn.Module):
-    """
-    Embed camera features to a high-dimensional vector.
-
-    Reference:
-    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
-    """
-    def __init__(self, raw_dim: int, embed_dim: int):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.Linear(raw_dim, embed_dim),
-            nn.SiLU(),
-            nn.Linear(embed_dim, embed_dim),
-        )
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
-class LRMGenerator(nn.Module):
-    """
-    Full model of the large reconstruction model.
-    """
-    def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int,
-                 transformer_dim: int, transformer_layers: int, transformer_heads: int,
-                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
-                 encoder_freeze: bool = True, encoder_model_name: str = 'facebook/dinov2-base', encoder_feat_dim: int = 768):
-        super().__init__()
-
-        # attributes
-        self.encoder_feat_dim = encoder_feat_dim
-        self.camera_embed_dim = camera_embed_dim
-
-        # modules
-        self.encoder = DinoWrapper(
-            model_name=encoder_model_name,
-            freeze=encoder_freeze,
-        )
-        self.camera_embedder = CameraEmbedder(
-            raw_dim=12+4, embed_dim=camera_embed_dim,
-        )
-        self.transformer = TriplaneTransformer(
-            inner_dim=transformer_dim, num_layers=transformer_layers, num_heads=transformer_heads,
-            image_feat_dim=encoder_feat_dim,
-            camera_embed_dim=camera_embed_dim,
-            triplane_low_res=triplane_low_res, triplane_high_res=triplane_high_res, triplane_dim=triplane_dim,
-        )
-        self.synthesizer = TriplaneSynthesizer(
-            triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray,
-        )
-
-    def forward(self, image, camera):
-        # image: [N, C_img, H_img, W_img]
-        # camera: [N, D_cam_raw]
-        assert image.shape[0] == camera.shape[0], "Batch size mismatch for image and camera"
-        N = image.shape[0]
-
-        # encode image
-        image_feats = self.encoder(image)
-        assert image_feats.shape[-1] == self.encoder_feat_dim, \
-            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
-
-        # embed camera
-        camera_embeddings = self.camera_embedder(camera)
-        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
-            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
-
-        # transformer generating planes
-        planes = self.transformer(image_feats, camera_embeddings)
-        assert planes.shape[0] == N, "Batch size mismatch for planes"
-        assert planes.shape[1] == 3, "Planes should have 3 channels"
-        return planes
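As a rough shape summary of LRMGenerator.forward (values taken from the inferrer's _model_kwargs; the triplane spatial size follows triplane_high_res and is an inference, not stated explicitly in this diff):

# image:             (N, 3, 512, 512)     source RGB in [0, 1]
# camera:            (N, 16)              flattened RT (12) + normalized fx, fy, cx, cy (4)
# image_feats:       (N, n_tokens, 768)   DINOv2 last_hidden_state
# camera_embeddings: (N, 1024)            CameraEmbedder MLP output
# planes:            (N, 3, 80, 64, 64)   triplane handed to TriplaneSynthesizer (assumed layout)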
lrm/models/rendering/__init__.py
DELETED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
lrm/models/rendering/synthesizer_part.py
DELETED
@@ -1,194 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import itertools
-import torch
-import torch.nn as nn
-
-from .utils.renderer import ImportanceRenderer
-from .utils.ray_sampler_part import RaySampler
-
-
-class OSGDecoder(nn.Module):
-    """
-    Triplane decoder that gives RGB and sigma values from sampled features.
-    Using ReLU here instead of Softplus in the original implementation.
-
-    Reference:
-    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
-    """
-    def __init__(self, n_features: int,
-                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(3 * n_features, hidden_dim),
-            activation(),
-            *itertools.chain(*[[
-                nn.Linear(hidden_dim, hidden_dim),
-                activation(),
-            ] for _ in range(num_layers - 2)]),
-            nn.Linear(hidden_dim, 1 + 3),
-        )
-        # init all bias to zero
-        for m in self.modules():
-            if isinstance(m, nn.Linear):
-                nn.init.zeros_(m.bias)
-
-    def forward(self, sampled_features, ray_directions):
-        # Aggregate features by mean
-        # sampled_features = sampled_features.mean(1)
-        # Aggregate features by concatenation
-        _N, n_planes, _M, _C = sampled_features.shape
-        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
-        x = sampled_features
-
-        N, M, C = x.shape
-        x = x.contiguous().view(N*M, C)
-
-        x = self.net(x)
-        x = x.view(N, M, -1)
-        rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
-        sigma = x[..., 0:1]
-
-        return {'rgb': rgb, 'sigma': sigma}
-
-
-class TriplaneSynthesizer(nn.Module):
-    """
-    Synthesizer that renders a triplane volume with planes and a camera.
-
-    Reference:
-    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
-    """
-
-    DEFAULT_RENDERING_KWARGS = {
-        'ray_start': 'auto',
-        'ray_end': 'auto',
-        'box_warp': 2.,
-        'white_back': True,
-        'disparity_space_sampling': False,
-        'clamp_mode': 'softplus',
-        'sampler_bbox_min': -1.,
-        'sampler_bbox_max': 1.,
-    }
-
-    def __init__(self, triplane_dim: int, samples_per_ray: int):
-        super().__init__()
-
-        # attributes
-        self.triplane_dim = triplane_dim
-        self.rendering_kwargs = {
-            **self.DEFAULT_RENDERING_KWARGS,
-            'depth_resolution': samples_per_ray // 2,
-            'depth_resolution_importance': samples_per_ray // 2,
-        }
-
-        # renderings
-        self.renderer = ImportanceRenderer()
-        self.ray_sampler = RaySampler()
-
-        # modules
-        self.decoder = OSGDecoder(n_features=triplane_dim)
-
-    def forward(self, planes, cameras, render_size: int, crop_size: int, start_x: int, start_y: int):
-        # planes: (N, 3, D', H', W')
-        # cameras: (N, M, D_cam)
-        # render_size: int
-        assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
-        N, M = cameras.shape[:2]
-        cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
-        intrinsics = cameras[..., 16:25].view(N, M, 3, 3)
-
-        # Create a batch of rays for volume rendering
-        ray_origins, ray_directions = self.ray_sampler(
-            cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
-            intrinsics=intrinsics.reshape(-1, 3, 3),
-            render_size=render_size,
-            crop_size=crop_size,
-            start_x=start_x,
-            start_y=start_y
-        )
-        assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
-        assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
-        # Perform volume rendering
-        rgb_samples, depth_samples, weights_samples = self.renderer(
-            planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
-        )
-
-        # Reshape into 'raw' neural-rendered image
-        Himg = Wimg = crop_size
-        rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
-        depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
-        weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
-
-        return {
-            'images_rgb': rgb_images,
-            'images_depth': depth_images,
-            'images_weight': weight_images,
-        }
-
-    def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
-        # planes: (N, 3, D', H', W')
-        # grid_size: int
-        # aabb: (N, 2, 3)
-        if aabb is None:
-            aabb = torch.tensor([
-                [self.rendering_kwargs['sampler_bbox_min']] * 3,
-                [self.rendering_kwargs['sampler_bbox_max']] * 3,
-            ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
-        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
-        N = planes.shape[0]
-
-        # create grid points for triplane query
-        grid_points = []
-        for i in range(N):
-            grid_points.append(torch.stack(torch.meshgrid(
-                torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
-                torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
-                torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
-                indexing='ij',
-            ), dim=-1).reshape(-1, 3))
-        cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
-
-        features = self.forward_points(planes, cube_grid)
-
-        # reshape into grid
-        features = {
-            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
-            for k, v in features.items()
-        }
-        return features
-
-    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
-        # planes: (N, 3, D', H', W')
-        # points: (N, P, 3)
-        N, P = points.shape[:2]
-
-        # query triplane in chunks
-        outs = []
-        for i in range(0, points.shape[1], chunk_size):
-            chunk_points = points[:, i:i+chunk_size]
-
-            # query triplane
-            chunk_out = self.renderer.run_model_activated(
-                planes=planes,
-                decoder=self.decoder,
-                sample_coordinates=chunk_points,
-                sample_directions=torch.zeros_like(chunk_points),
-                options=self.rendering_kwargs,
-            )
-            outs.append(chunk_out)
-
-        # concatenate the outputs
-        point_features = {
-            k: torch.cat([out[k] for out in outs], dim=1)
-            for k in outs[0].keys()
-        }
-
-        sig = point_features['sigma']
-        print(sig.mean(), sig.max(), sig.min())
-        return point_features
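forward_grid is what feeds marching cubes in the inferrer; the conversion from a sampled sigma grid to a mesh follows the pattern below, shown standalone on a synthetic sphere so it runs without the model (mcubes and trimesh as in inferrer.py):

import numpy as np
import mcubes
import trimesh

grid_size, thresh = 64, 1.0
# synthetic density grid standing in for grid_out['sigma']: a soft sphere of radius 0.5
xs = np.linspace(-1, 1, grid_size)
x, y, z = np.meshgrid(xs, xs, xs, indexing='ij')
sigma = 5.0 * (0.5 - np.sqrt(x**2 + y**2 + z**2))   # positive inside the sphere

vtx, faces = mcubes.marching_cubes(sigma, thresh)
vtx = vtx / (grid_size - 1) * 2 - 1                 # map back to [-1, 1], as in infer_single
mesh = trimesh.Trimesh(vertices=vtx, faces=faces)
mesh.export('sphere.obj', 'obj')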
lrm/models/rendering/utils/__init__.py
DELETED
@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-#
-# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
-# property and proprietary rights in and to this material, related
-# documentation and any modifications thereto. Any use, reproduction,
-# disclosure or distribution of this material and related documentation
-# without an express license agreement from NVIDIA CORPORATION or
-# its affiliates is strictly prohibited.
lrm/models/rendering/utils/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (192 Bytes)

lrm/models/rendering/utils/__pycache__/math_utils.cpython-310.pyc
DELETED
Binary file (2.76 kB)

lrm/models/rendering/utils/__pycache__/ray_marcher.cpython-310.pyc
DELETED
Binary file (2.01 kB)

lrm/models/rendering/utils/__pycache__/ray_sampler_part.cpython-310.pyc
DELETED
Binary file (2.75 kB)

lrm/models/rendering/utils/__pycache__/renderer.cpython-310.pyc
DELETED
Binary file (10.5 kB)
lrm/models/rendering/utils/math_utils.py
DELETED
@@ -1,123 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# MIT License
-
-# Copyright (c) 2022 Petr Kellnhofer
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-import torch
-
-def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
-    """
-    Left-multiplies MxM @ NxM. Returns NxM.
-    """
-    res = torch.matmul(vectors4, matrix.T)
-    return res
-
-
-def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
-    """
-    Normalize vector lengths.
-    """
-    return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
-
-def torch_dot(x: torch.Tensor, y: torch.Tensor):
-    """
-    Dot product of two tensors.
-    """
-    return (x * y).sum(-1)
-
-
-def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
-    """
-    Author: Petr Kellnhofer
-    Intersects rays with the [-1, 1] NDC volume.
-    Returns min and max distance of entry.
-    Returns -1 for no intersection.
-    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
-    """
-    o_shape = rays_o.shape
-    rays_o = rays_o.detach().reshape(-1, 3)
-    rays_d = rays_d.detach().reshape(-1, 3)
-
-
-    bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
-    bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
-    bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
-    is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
-
-    # Precompute inverse for stability.
-    invdir = 1 / rays_d
-    sign = (invdir < 0).long()
-
-    # Intersect with YZ plane.
-    tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
-    tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
-
-    # Intersect with XZ plane.
-    tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
-    tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
-
-    # Resolve parallel rays.
-    is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
-
-    # Use the shortest intersection.
-    tmin = torch.max(tmin, tymin)
-    tmax = torch.min(tmax, tymax)
-
-    # Intersect with XY plane.
-    tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
-    tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
-
-    # Resolve parallel rays.
-    is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
-
-    # Use the shortest intersection.
-    tmin = torch.max(tmin, tzmin)
-    tmax = torch.min(tmax, tzmax)
-
-    # Mark invalid.
-    tmin[torch.logical_not(is_valid)] = -1
-    tmax[torch.logical_not(is_valid)] = -2
-
-    return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
-
-
-def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
-    """
-    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
-    Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
-    """
-    # create a tensor of 'num' steps from 0 to 1
-    steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
-
-    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
-    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
-    #   "cannot statically infer the expected size of a list in this contex", hence the code below
-    for i in range(start.ndim):
-        steps = steps.unsqueeze(-1)
-
-    # the output starts at 'start' and increments until 'stop' in each dimension
-    out = start[None] + steps * (stop - start)[None]
-
-    return out
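get_ray_limits_box is what gives the renderer its 'auto' near/far bounds. A compact standalone sanity check of the same slab test (an equivalent dense rewrite, not the deleted function itself; box side 2 matches box_warp):

import torch

def ray_box_limits(rays_o, rays_d, side):
    # slab test against an axis-aligned cube of the given side, centred at the origin
    half = side / 2
    t0 = (-half - rays_o) / rays_d
    t1 = ( half - rays_o) / rays_d
    tmin = torch.minimum(t0, t1).max(dim=-1).values
    tmax = torch.maximum(t0, t1).min(dim=-1).values
    return tmin, tmax

o = torch.tensor([[0.0, 0.0, 3.0]])     # camera 3 units out on +z
d = torch.tensor([[0.0, 0.0, -1.0]])    # looking at the origin
print(ray_box_limits(o, d, 2.0))        # (tensor([2.]), tensor([4.]))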
lrm/models/rendering/utils/ray_marcher.py
DELETED
@@ -1,73 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-#
-# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
-# property and proprietary rights in and to this material, related
-# documentation and any modifications thereto. Any use, reproduction,
-# disclosure or distribution of this material and related documentation
-# without an express license agreement from NVIDIA CORPORATION or
-# its affiliates is strictly prohibited.
-#
-# Modified by Zexin He
-# The modifications are subject to the same license as the original.
-
-
-"""
-The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
-Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
-"""
-
-import torch
-import torch.nn as nn
-
-
-class MipRayMarcher2(nn.Module):
-    def __init__(self, activation_factory):
-        super().__init__()
-        self.activation_factory = activation_factory
-
-    def run_forward(self, colors, densities, depths, rendering_options):
-
-        deltas = depths[:, :, 1:] - depths[:, :, :-1]
-        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
-        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
-        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
-
-        # using factory mode for better usability
-        densities_mid = self.activation_factory(rendering_options)(densities_mid)
-
-        density_delta = densities_mid * deltas
-
-        alpha = 1 - torch.exp(-density_delta)
-
-        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
-        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
-
-        composite_rgb = torch.sum(weights * colors_mid, -2)
-        weight_total = weights.sum(2)
-        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
-
-        # clip the composite to min/max range of depths
-        composite_depth = torch.nan_to_num(composite_depth, float('inf'))
-        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
-
-        if rendering_options.get('white_back', False):
-            composite_rgb = composite_rgb + 1 - weight_total
-
-        # rendered value scale is 0-1, comment out original mipnerf scaling
-        # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
-
-        return composite_rgb, composite_depth, weights
-
-
-    def forward(self, colors, densities, depths, rendering_options):
-        composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
-
-        return composite_rgb, composite_depth, weights
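MipRayMarcher2.run_forward is the standard discrete volume-rendering quadrature. Stripped to a single ray with toy densities and colours, the core computation is roughly:

import torch

sigma  = torch.tensor([0.0, 2.0, 5.0, 0.5])   # densities at sample midpoints (toy values)
deltas = torch.full((4,), 0.1)                # spacing between depth samples
rgb    = torch.rand(4, 3)                     # colours at the same samples

alpha = 1 - torch.exp(-sigma * deltas)                                         # per-sample opacity
trans = torch.cumprod(torch.cat([torch.ones(1), 1 - alpha + 1e-10])[:-1], 0)   # transmittance
weights = alpha * trans
pixel = (weights[:, None] * rgb).sum(0)       # composited colour
pixel = pixel + 1 - weights.sum()             # white background, matching white_back=True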
lrm/models/rendering/utils/ray_sampler_part.py
DELETED
@@ -1,94 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-#
-# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
-# property and proprietary rights in and to this material, related
-# documentation and any modifications thereto. Any use, reproduction,
-# disclosure or distribution of this material and related documentation
-# without an express license agreement from NVIDIA CORPORATION or
-# its affiliates is strictly prohibited.
-#
-# Modified by Zexin He
-# The modifications are subject to the same license as the original.
-
-
-"""
-The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
-Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
-"""
-
-import torch
-
-class RaySampler(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
-
-
-    def forward(self, cam2world_matrix, intrinsics, render_size, crop_size, start_x, start_y):
-        """
-        Create batches of rays and return origins and directions.
-
-        cam2world_matrix: (N, 4, 4)
-        intrinsics: (N, 3, 3)
-        render_size: int
-
-        ray_origins: (N, M, 3)
-        ray_dirs: (N, M, 2)
-        """
-
-        N, M = cam2world_matrix.shape[0], crop_size**2
-        cam_locs_world = cam2world_matrix[:, :3, 3]
-        fx = intrinsics[:, 0, 0]
-        fy = intrinsics[:, 1, 1]
-        cx = intrinsics[:, 0, 2]
-        cy = intrinsics[:, 1, 2]
-        sk = intrinsics[:, 0, 1]
-
-        uv = torch.stack(torch.meshgrid(
-            torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
-            torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
-            indexing='ij',
-        ))
-        if crop_size < render_size:
-            patch_uv = []
-            for i in range(cam2world_matrix.shape[0]):
-                patch_uv.append(uv.clone()[None, :, start_y:start_y+crop_size, start_x:start_x+crop_size])
-            uv = torch.cat(patch_uv, 0)
-            uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
-        else:
-            uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
-            uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
-        # uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
-        # uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
-        x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
-        y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
-        z_cam = torch.ones((N, M), device=cam2world_matrix.device)
-
-        x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
-        y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
-
-        cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).float()
-
-        _opencv2blender = torch.tensor([
-            [1, 0, 0, 0],
-            [0, -1, 0, 0],
-            [0, 0, -1, 0],
-            [0, 0, 0, 1],
-        ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
-
-        # added float here
-        cam2world_matrix = torch.bmm(cam2world_matrix.float(), _opencv2blender.float())
-
-        world_rel_points = torch.bmm(cam2world_matrix.float(), cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
-
-        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
-        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
-
-        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
-        return ray_origins, ray_dirs
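The sampler above lifts each pixel centre through the normalized intrinsics and rotates it into world space (it additionally converts OpenCV to Blender axes, which the sketch below omits). For a single skew-free pinhole camera the idea reduces to:

import torch

fx = fy = 384.0 / 512.0          # normalized intrinsics, as produced by get_normalized_camera_intrinsics
cx = cy = 256.0 / 512.0
c2w = torch.eye(4)               # camera-to-world pose (identity just for the sketch)
H = W = 4                        # tiny render size for illustration

v, u = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
x = ((u.float() + 0.5) / W - cx) / fx
y = ((v.float() + 0.5) / H - cy) / fy
dirs_cam = torch.stack([x, y, torch.ones_like(x)], dim=-1).reshape(-1, 3)   # OpenCV convention, +z forward
dirs_world = dirs_cam @ c2w[:3, :3].T
dirs_world = torch.nn.functional.normalize(dirs_world, dim=-1)
origins = c2w[:3, 3].expand_as(dirs_world)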
lrm/models/rendering/utils/renderer.py
DELETED
@@ -1,314 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-#
-# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
-# property and proprietary rights in and to this material, related
-# documentation and any modifications thereto. Any use, reproduction,
-# disclosure or distribution of this material and related documentation
-# without an express license agreement from NVIDIA CORPORATION or
-# its affiliates is strictly prohibited.
-#
-# Modified by Zexin He
-# The modifications are subject to the same license as the original.
-
-
-"""
-The renderer is a module that takes in rays, decides where to sample along each
-ray, and computes pixel colors using the volume rendering equation.
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .ray_marcher import MipRayMarcher2
-from . import math_utils
-
-def generate_planes():
-    """
-    Defines planes by the three vectors that form the "axes" of the
-    plane. Should work with arbitrary number of planes and planes of
-    arbitrary orientation.
-
-    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
-    """
-    return torch.tensor([[[1, 0, 0],
-                          [0, 1, 0],
-                          [0, 0, 1]],
-                         [[1, 0, 0],
-                          [0, 0, 1],
-                          [0, 1, 0]],
-                         [[0, 0, 1],
-                          [0, 1, 0],
-                          [1, 0, 0]]], dtype=torch.float32)
-
-def project_onto_planes(planes, coordinates):
-    """
-    Does a projection of a 3D point onto a batch of 2D planes,
-    returning 2D plane coordinates.
-
-    Takes plane axes of shape n_planes, 3, 3
-    # Takes coordinates of shape N, M, 3
-    # returns projections of shape N*n_planes, M, 2
-    """
-    N, M, C = coordinates.shape
-    n_planes, _, _ = planes.shape
-    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
-    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
-    coordinates = coordinates.to(inv_planes.device)
-    projections = torch.bmm(coordinates, inv_planes)
-    return projections[..., :2]
-
-def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
-    assert padding_mode == 'zeros'
-    N, n_planes, C, H, W = plane_features.shape
-    _, M, _ = coordinates.shape
-    plane_features = plane_features.view(N*n_planes, C, H, W)
-
-    coordinates = (2/box_warp) * coordinates  # add specific box bounds
-    # half added here
-    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
-    # removed float from projected_coordinates
-    output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
-    return output_features
-
-def sample_from_3dgrid(grid, coordinates):
-    """
-    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
-    Expects grid in shape (1, channels, H, W, D)
-    (Also works if grid has batch size)
-    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
-    """
-    batch_size, n_coords, n_dims = coordinates.shape
-    sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
-                                                       coordinates.reshape(batch_size, 1, 1, -1, n_dims),
-                                                       mode='bilinear', padding_mode='zeros', align_corners=False)
-    N, C, H, W, D = sampled_features.shape
-    sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
-    return sampled_features
-
-class ImportanceRenderer(torch.nn.Module):
-    """
-    Modified original version to filter out-of-box samples as TensoRF does.
-
-    Reference:
-    TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
-    """
-    def __init__(self):
-        super().__init__()
-        self.activation_factory = self._build_activation_factory()
-        self.ray_marcher = MipRayMarcher2(self.activation_factory)
-        self.plane_axes = generate_planes()
-
-    def _build_activation_factory(self):
-        def activation_factory(options: dict):
-            if options['clamp_mode'] == 'softplus':
-                return lambda x: F.softplus(x - 1)  # activation bias of -1 makes things initialize better
-            else:
-                assert False, "Renderer only supports `clamp_mode`=`softplus`!"
-        return activation_factory
-
-    def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
-                      planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
-        """
-        Additional filtering is applied to filter out-of-box samples.
-        Modifications made by Zexin He.
-        """
-
-        # context related variables
-        batch_size, num_rays, samples_per_ray, _ = depths.shape
-        device = planes.device
-        depths = depths.to(device)
-        ray_directions = ray_directions.to(device)
-        ray_origins = ray_origins.to(device)
-        # define sample points with depths
-        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
-        sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
-
-        # filter out-of-box samples
-        mask_inbox = \
-            (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
-            (sample_coordinates <= rendering_options['sampler_bbox_max'])
-        mask_inbox = mask_inbox.all(-1)
-
-        # forward model according to all samples
-        _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
-
-        # set out-of-box samples to zeros(rgb) & -inf(sigma)
-        SAFE_GUARD = 3
-        DATA_TYPE = _out['sigma'].dtype
-        colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
-        densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
-        colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
-
-        # reshape back
-        colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
-        densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
-
-        return colors_pass, densities_pass
-
-    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
-        # self.plane_axes = self.plane_axes.to(ray_origins.device)
-
-        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
-            ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
-            is_ray_valid = ray_end > ray_start
-            if torch.any(is_ray_valid).item():
-                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
-                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
-            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
-        else:
-            # Create stratified depth samples
-            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
-
-        depths_coarse = depths_coarse.to(planes.device)
-
-        # Coarse Pass
-        colors_coarse, densities_coarse = self._forward_pass(
-            depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
-            planes=planes, decoder=decoder, rendering_options=rendering_options)
-
-        # Fine Pass
-        N_importance = rendering_options['depth_resolution_importance']
-        if N_importance > 0:
-            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
-
-            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
-
-            colors_fine, densities_fine = self._forward_pass(
-                depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
-                planes=planes, decoder=decoder, rendering_options=rendering_options)
-
-            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
-                                                                       depths_fine, colors_fine, densities_fine)
-
-            # Aggregate
-            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
-        else:
-            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
-
-        return rgb_final, depth_final, weights.sum(2)
-
-    def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
|
198 |
-
plane_axes = self.plane_axes.to(planes.device)
|
199 |
-
sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
|
200 |
-
|
201 |
-
out = decoder(sampled_features, sample_directions)
|
202 |
-
if options.get('density_noise', 0) > 0:
|
203 |
-
out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
|
204 |
-
return out
|
205 |
-
|
206 |
-
def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
|
207 |
-
out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
|
208 |
-
out['sigma'] = self.activation_factory(options)(out['sigma'])
|
209 |
-
return out
|
210 |
-
|
211 |
-
def sort_samples(self, all_depths, all_colors, all_densities):
|
212 |
-
_, indices = torch.sort(all_depths, dim=-2)
|
213 |
-
all_depths = torch.gather(all_depths, -2, indices)
|
214 |
-
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
|
215 |
-
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
|
216 |
-
return all_depths, all_colors, all_densities
|
217 |
-
|
218 |
-
def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
|
219 |
-
all_depths = torch.cat([depths1, depths2], dim = -2)
|
220 |
-
all_colors = torch.cat([colors1, colors2], dim = -2)
|
221 |
-
all_densities = torch.cat([densities1, densities2], dim = -2)
|
222 |
-
|
223 |
-
_, indices = torch.sort(all_depths, dim=-2)
|
224 |
-
all_depths = torch.gather(all_depths, -2, indices)
|
225 |
-
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
|
226 |
-
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
|
227 |
-
|
228 |
-
return all_depths, all_colors, all_densities
|
229 |
-
|
230 |
-
def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
|
231 |
-
"""
|
232 |
-
Return depths of approximately uniformly spaced samples along rays.
|
233 |
-
"""
|
234 |
-
N, M, _ = ray_origins.shape
|
235 |
-
if disparity_space_sampling:
|
236 |
-
depths_coarse = torch.linspace(0,
|
237 |
-
1,
|
238 |
-
depth_resolution,
|
239 |
-
device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
|
240 |
-
depth_delta = 1/(depth_resolution - 1)
|
241 |
-
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
|
242 |
-
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
|
243 |
-
else:
|
244 |
-
if type(ray_start) == torch.Tensor:
|
245 |
-
depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
|
246 |
-
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
|
247 |
-
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
|
248 |
-
else:
|
249 |
-
depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
|
250 |
-
depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
|
251 |
-
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
|
252 |
-
|
253 |
-
return depths_coarse
|
254 |
-
|
255 |
-
def sample_importance(self, z_vals, weights, N_importance):
|
256 |
-
"""
|
257 |
-
Return depths of importance sampled points along rays. See NeRF importance sampling for more.
|
258 |
-
"""
|
259 |
-
with torch.no_grad():
|
260 |
-
batch_size, num_rays, samples_per_ray, _ = z_vals.shape
|
261 |
-
|
262 |
-
z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
|
263 |
-
weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
|
264 |
-
|
265 |
-
# smooth weights
|
266 |
-
weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
|
267 |
-
weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
|
268 |
-
weights = weights + 0.01
|
269 |
-
|
270 |
-
z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
|
271 |
-
importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
|
272 |
-
N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
|
273 |
-
return importance_z_vals
|
274 |
-
|
275 |
-
def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
|
276 |
-
"""
|
277 |
-
Sample @N_importance samples from @bins with distribution defined by @weights.
|
278 |
-
Inputs:
|
279 |
-
bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
|
280 |
-
weights: (N_rays, N_samples_)
|
281 |
-
N_importance: the number of samples to draw from the distribution
|
282 |
-
det: deterministic or not
|
283 |
-
eps: a small number to prevent division by zero
|
284 |
-
Outputs:
|
285 |
-
samples: the sampled samples
|
286 |
-
"""
|
287 |
-
N_rays, N_samples_ = weights.shape
|
288 |
-
weights = weights + eps # prevent division by zero (don't do inplace op!)
|
289 |
-
pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
|
290 |
-
cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
|
291 |
-
cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
|
292 |
-
# padded to 0~1 inclusive
|
293 |
-
|
294 |
-
if det:
|
295 |
-
u = torch.linspace(0, 1, N_importance, device=bins.device)
|
296 |
-
u = u.expand(N_rays, N_importance)
|
297 |
-
else:
|
298 |
-
u = torch.rand(N_rays, N_importance, device=bins.device)
|
299 |
-
u = u.contiguous()
|
300 |
-
|
301 |
-
inds = torch.searchsorted(cdf, u, right=True)
|
302 |
-
below = torch.clamp_min(inds-1, 0)
|
303 |
-
above = torch.clamp_max(inds, N_samples_)
|
304 |
-
|
305 |
-
inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
|
306 |
-
cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
|
307 |
-
bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
|
308 |
-
|
309 |
-
denom = cdf_g[...,1]-cdf_g[...,0]
|
310 |
-
denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
|
311 |
-
# anyway, therefore any value for it is fine (set to 1 here)
|
312 |
-
|
313 |
-
samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
|
314 |
-
return samples
|
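
Note on the deleted helpers: `sample_from_planes` is the EG3D-style tri-plane lookup. Query points are rescaled by `box_warp`, projected onto the three axis-aligned planes, and read out with bilinear `grid_sample`. The sketch below shows only that idea; the function name, shapes, and the simplified coordinate-dropping projection are illustrative and differ from the deleted `project_onto_planes`, which uses inverse plane matrices.

import torch
import torch.nn.functional as F

def sample_from_planes_sketch(plane_features, points, box_warp=2.0):
    # plane_features: (N, 3, C, H, W), one feature map per axis-aligned plane
    # points: (N, M, 3), query points inside a box of side length box_warp
    N, n_planes, C, H, W = plane_features.shape
    points = (2.0 / box_warp) * points                              # normalize to [-1, 1]
    # drop one coordinate per plane: XY, XZ, YZ projections
    proj = torch.stack([points[..., [0, 1]],
                        points[..., [0, 2]],
                        points[..., [1, 2]]], dim=1)                # (N, 3, M, 2)
    grid = proj.reshape(N * n_planes, 1, -1, 2)                     # grid_sample expects (B, H_out, W_out, 2)
    feats = F.grid_sample(plane_features.reshape(N * n_planes, C, H, W),
                          grid, mode='bilinear', padding_mode='zeros',
                          align_corners=False)                      # (N*3, C, 1, M)
    return feats.reshape(N, n_planes, C, -1).permute(0, 1, 3, 2)    # (N, 3, M, C)

planes = torch.randn(2, 3, 32, 64, 64)
pts = torch.rand(2, 1024, 3) * 2 - 1
print(sample_from_planes_sketch(planes, pts).shape)                 # torch.Size([2, 3, 1024, 32])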
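
The deleted `sample_pdf` is the standard NeRF inverse-CDF (hierarchical) sampler: coarse-pass weights define a piecewise-constant PDF over depth bins, the CDF is inverted at uniform draws, and fine samples land where the coarse pass found density. A condensed, self-contained sketch of the same computation; the function name, variable names, and shapes below are illustrative, not taken from the deleted file.

import torch

def inverse_cdf_sample(bins, weights, n_samples, eps=1e-5):
    # bins: (R, B+1) bin edges along each ray; weights: (R, B) coarse-pass weights
    pdf = (weights + eps) / (weights + eps).sum(-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)       # (R, B+1), starts at 0
    u = torch.rand(bins.shape[0], n_samples, device=bins.device)    # uniform draws in [0, 1)
    idx = torch.searchsorted(cdf, u, right=True)                    # first CDF entry above each draw
    below = (idx - 1).clamp(min=0)
    above = idx.clamp(max=weights.shape[-1])
    cdf_lo, cdf_hi = torch.gather(cdf, 1, below), torch.gather(cdf, 1, above)
    bin_lo, bin_hi = torch.gather(bins, 1, below), torch.gather(bins, 1, above)
    denom = (cdf_hi - cdf_lo).clamp(min=eps)                        # guard zero-weight bins
    t = (u - cdf_lo) / denom
    return bin_lo + t * (bin_hi - bin_lo)                           # (R, n_samples) new depths

depths = torch.linspace(0.5, 2.5, 65).repeat(4, 1)   # 4 rays, 64 depth bins
w = torch.rand(4, 64)
print(inverse_cdf_sample(depths, w, 32).shape)        # torch.Size([4, 32])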
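
Anyone restoring `ImportanceRenderer` from history needs a `rendering_options` dict. The keys below are the ones the deleted code actually reads; the values are placeholders, not the configuration the project shipped with.

# Hypothetical values -- only the keys are taken from the deleted renderer.
rendering_options = {
    'ray_start': 'auto',                 # or a float near depth
    'ray_end': 'auto',                   # or a float far depth
    'box_warp': 2.0,                     # side length of the sampling box
    'depth_resolution': 64,              # coarse samples per ray
    'depth_resolution_importance': 64,   # fine (importance) samples per ray
    'disparity_space_sampling': False,
    'clamp_mode': 'softplus',            # the only mode the activation factory accepts
    'sampler_bbox_min': -1.0,
    'sampler_bbox_max': 1.0,
    'density_noise': 0.0,
}
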
lrm/models/transformer.py
DELETED
@@ -1,135 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch
-import torch.nn as nn
-
-
-class ModLN(nn.Module):
-    """
-    Modulation with adaLN.
-
-    References:
-    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
-    """
-    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
-        super().__init__()
-        self.norm = nn.LayerNorm(inner_dim, eps=eps)
-        self.mlp = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(mod_dim, inner_dim * 2),
-        )
-
-    @staticmethod
-    def modulate(x, shift, scale):
-        # x: [N, L, D]
-        # shift, scale: [N, D]
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-    def forward(self, x, cond):
-        shift, scale = self.mlp(cond).chunk(2, dim=-1)  # [N, D]
-        return self.modulate(self.norm(x), shift, scale)  # [N, L, D]
-
-
-class ConditionModulationBlock(nn.Module):
-    """
-    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
-    """
-    # use attention from torch.nn.MultiHeadAttention
-    # Block contains a cross-attention layer, a self-attention layer, and a MLP
-    def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
-                 attn_drop: float = 0., attn_bias: bool = False,
-                 mlp_ratio: float = 4., mlp_drop: float = 0.):
-        super().__init__()
-        self.norm1 = ModLN(inner_dim, mod_dim, eps)
-        self.cross_attn = nn.MultiheadAttention(
-            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
-            dropout=attn_drop, bias=attn_bias, batch_first=True)
-        self.norm2 = ModLN(inner_dim, mod_dim, eps)
-        self.self_attn = nn.MultiheadAttention(
-            embed_dim=inner_dim, num_heads=num_heads,
-            dropout=attn_drop, bias=attn_bias, batch_first=True)
-        self.norm3 = ModLN(inner_dim, mod_dim, eps)
-        self.mlp = nn.Sequential(
-            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
-            nn.GELU(),
-            nn.Dropout(mlp_drop),
-            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
-            nn.Dropout(mlp_drop),
-        )
-
-    def forward(self, x, cond, mod):
-        # x: [N, L, D]
-        # cond: [N, L_cond, D_cond]
-        # mod: [N, D_mod]
-        x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
-        before_sa = self.norm2(x, mod)
-        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
-        x = x + self.mlp(self.norm3(x, mod))
-        return x
-
-
-class TriplaneTransformer(nn.Module):
-    """
-    Transformer with condition and modulation that generates a triplane representation.
-
-    Reference:
-    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
-    """
-    def __init__(self, inner_dim: int, image_feat_dim: int, camera_embed_dim: int,
-                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
-                 num_layers: int, num_heads: int,
-                 eps: float = 1e-6):
-        super().__init__()
-
-        # attributes
-        self.triplane_low_res = triplane_low_res
-        self.triplane_high_res = triplane_high_res
-        self.triplane_dim = triplane_dim
-
-        # modules
-        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
-        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
-        self.layers = nn.ModuleList([
-            ConditionModulationBlock(
-                inner_dim=inner_dim, cond_dim=image_feat_dim, mod_dim=camera_embed_dim, num_heads=num_heads, eps=eps)
-            for _ in range(num_layers)
-        ])
-        self.norm = nn.LayerNorm(inner_dim, eps=eps)
-        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
-
-    def forward(self, image_feats, camera_embeddings):
-        # image_feats: [N, L_cond, D_cond]
-        # camera_embeddings: [N, D_mod]
-
-        assert image_feats.shape[0] == camera_embeddings.shape[0], \
-            f"Mismatched batch size: {image_feats.shape[0]} vs {camera_embeddings.shape[0]}"
-
-        N = image_feats.shape[0]
-        H = W = self.triplane_low_res
-        L = 3 * H * W
-
-        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
-        for layer in self.layers:
-            x = layer(x, image_feats, camera_embeddings)
-        x = self.norm(x)
-
-        # separate each plane and apply deconv
-        x = x.view(N, 3, H, W, -1)
-        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
-        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
-        x = self.deconv(x)  # [3*N, D', H', W']
-        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
-        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
-        x = x.contiguous()
-
-        assert self.triplane_high_res == x.shape[-2], \
-            f"Output triplane resolution does not match with expected: {x.shape[-2]} vs {self.triplane_high_res}"
-        assert self.triplane_dim == x.shape[-3], \
-            f"Output triplane dimension does not match with expected: {x.shape[-3]} vs {self.triplane_dim}"
-
-        return x
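
The deleted `ModLN` is DiT-style adaLN: a conditioning vector passes through SiLU + Linear to produce a per-sample shift and scale that modulate a LayerNorm'd activation. A self-contained sketch of the same pattern; the class name and dimensions below are illustrative.

import torch
import torch.nn as nn

class AdaLN(nn.Module):
    # LayerNorm whose shift/scale come from a conditioning vector (DiT-style adaLN).
    def __init__(self, dim: int, cond_dim: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps)
        self.to_shift_scale = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))

    def forward(self, x, cond):
        # x: (N, L, D) tokens, cond: (N, D_cond) per-sample condition
        shift, scale = self.to_shift_scale(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

tokens = torch.randn(2, 16, 512)
camera = torch.randn(2, 128)
print(AdaLN(512, 128)(tokens, camera).shape)  # torch.Size([2, 16, 512])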
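
The tail of the deleted `TriplaneTransformer.forward` unpacks the 3·H·W output tokens into three low-resolution plane feature maps and upsamples them with a stride-2 transposed convolution. A shape-only sketch of that unpacking; all dimensions below are made up.

import torch
import torch.nn as nn

N, H, W, D = 2, 32, 32, 64              # hypothetical batch, low-res plane size, token dim
triplane_dim = 32
tokens = torch.randn(N, 3 * H * W, D)    # transformer output: one token per plane cell

x = tokens.view(N, 3, H, W, D)
x = torch.einsum('nihwd->indhw', x)               # (3, N, D, H, W)
x = x.contiguous().view(3 * N, D, H, W)           # fold the plane index into the batch
deconv = nn.ConvTranspose2d(D, triplane_dim, kernel_size=2, stride=2)
x = deconv(x)                                     # (3*N, triplane_dim, 2H, 2W)
x = x.view(3, N, triplane_dim, 2 * H, 2 * W).permute(1, 0, 2, 3, 4).contiguous()
print(x.shape)                                    # torch.Size([2, 3, 32, 64, 64])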