jadechoghari commited on
Commit
e87a6fe
·
verified ·
1 Parent(s): a3e0040

Delete lrm

Browse files
lrm/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
lrm/cam_utils.py DELETED
@@ -1,138 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch
9
- import numpy as np
10
- import math
11
-
12
- """
13
- R: (N, 3, 3)
14
- T: (N, 3)
15
- E: (N, 4, 4)
16
- vector: (N, 3)
17
- """
18
-
19
-
20
- def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
21
- """
22
- Compose the standard form extrinsic matrix from R and T.
23
- Batched I/O.
24
- """
25
- RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
26
- return compose_extrinsic_RT(RT)
27
-
28
-
29
- def compose_extrinsic_RT(RT: torch.Tensor):
30
- """
31
- Compose the standard form extrinsic matrix from RT.
32
- Batched I/O.
33
- """
34
- return torch.cat([
35
- RT,
36
- torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(RT.shape[0], 1, 1).to(RT.device)
37
- ], dim=1)
38
-
39
-
40
- def decompose_extrinsic_R_T(E: torch.Tensor):
41
- """
42
- Decompose the standard extrinsic matrix into R and T.
43
- Batched I/O.
44
- """
45
- RT = decompose_extrinsic_RT(E)
46
- return RT[:, :, :3], RT[:, :, 3]
47
-
48
-
49
- def decompose_extrinsic_RT(E: torch.Tensor):
50
- """
51
- Decompose the standard extrinsic matrix into RT.
52
- Batched I/O.
53
- """
54
- return E[:, :3, :]
55
-
56
-
57
- def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
58
- """
59
- intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
60
- Return batched fx, fy, cx, cy
61
- """
62
- fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
63
- cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
64
- width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
65
- fx, fy = fx / width, fy / height
66
- cx, cy = cx / width, cy / height
67
- return fx, fy, cx, cy
68
-
69
-
70
- def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
71
- """
72
- RT: (N, 3, 4)
73
- intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
74
- """
75
- fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
76
- return torch.cat([
77
- RT.reshape(-1, 12),
78
- fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
79
- ], dim=-1)
80
-
81
-
82
- def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
83
- """
84
- RT: (N, 3, 4)
85
- intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
86
- """
87
- E = compose_extrinsic_RT(RT)
88
- fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
89
- I = torch.stack([
90
- torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
91
- torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
92
- torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
93
- ], dim=1)
94
- return torch.cat([
95
- E.reshape(-1, 16),
96
- I.reshape(-1, 9),
97
- ], dim=-1)
98
-
99
-
100
- def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
101
- """
102
- camera_position: (M, 3)
103
- look_at: (3)
104
- up_world: (3)
105
- return: (M, 3, 4)
106
- """
107
- # by default, looking at the origin and world up is pos-z
108
- if look_at is None:
109
- look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
110
- if up_world is None:
111
- up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
112
- look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
113
- up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
114
-
115
- z_axis = camera_position - look_at
116
- z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
117
- x_axis = torch.cross(up_world, z_axis)
118
- x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
119
- y_axis = torch.cross(z_axis, x_axis)
120
- y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
121
- extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
122
- return extrinsics
123
-
124
- def get_surrounding_views(M, radius, elevation):
125
- # convert spherical coordinates (radius, azimuth, elevation) to Cartesian coordinates (x, y, z).
126
- camera_positions = []
127
- rand_theta= np.random.uniform(0, np.pi/180)
128
- elevation = math.radians(elevation)
129
- for i in range(M):
130
- theta = 2 * math.pi * i / M + rand_theta
131
- x = radius * math.cos(theta) * math.cos(elevation)
132
- y = radius * math.sin(theta) * math.cos(elevation)
133
- z = radius * math.sin(elevation)
134
- camera_positions.append([x, y, z])
135
- camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
136
- extrinsics = center_looking_at_camera_pose(camera_positions)
137
-
138
- return extrinsics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/inferrer.py DELETED
@@ -1,232 +0,0 @@
1
- import torch
2
- import math
3
- import os
4
- import imageio
5
- import mcubes
6
- import trimesh
7
- import numpy as np
8
- import argparse
9
- from torchvision.utils import save_image
10
- from PIL import Image
11
- import glob
12
- from .models.generator import LRMGenerator # Make sure this import is correct
13
- from .cam_utils import build_camera_principle, build_camera_standard, center_looking_at_camera_pose # Make sure this import is correct
14
- from functools import partial
15
- from rembg import remove, new_session
16
- from kiui.op import recenter
17
- import kiui
18
-
19
- class LRMInferrer:
20
- def __init__(self, model_name: str, resume: str):
21
- print("Initializing LRMInferrer")
22
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
- _model_kwargs = {'camera_embed_dim': 1024, 'rendering_samples_per_ray': 128, 'transformer_dim': 1024, 'transformer_layers': 16, 'transformer_heads': 16, 'triplane_low_res': 32, 'triplane_high_res': 64, 'triplane_dim': 80, 'encoder_freeze': False}
24
-
25
- self.model = self._build_model(_model_kwargs).eval().to(self.device)
26
- checkpoint = torch.load(resume, map_location='cpu')
27
- state_dict = checkpoint['model_state_dict']
28
- self.model.load_state_dict(state_dict)
29
- del checkpoint, state_dict
30
- torch.cuda.empty_cache()
31
-
32
- def __enter__(self):
33
- print("Entering context")
34
- return self
35
-
36
- def __exit__(self, exc_type, exc_val, exc_tb):
37
- print("Exiting context")
38
- if exc_type:
39
- print(f"Exception type: {exc_type}")
40
- print(f"Exception value: {exc_val}")
41
- print(f"Traceback: {exc_tb}")
42
-
43
- def _build_model(self, model_kwargs):
44
- print("Building model")
45
- model = LRMGenerator(**model_kwargs).to(self.device)
46
- print("Loaded model from checkpoint")
47
- return model
48
-
49
- @staticmethod
50
- def get_surrounding_views(M, radius, elevation):
51
- camera_positions = []
52
- rand_theta = np.random.uniform(0, np.pi/180)
53
- elevation = math.radians(elevation)
54
- for i in range(M):
55
- theta = 2 * math.pi * i / M + rand_theta
56
- x = radius * math.cos(theta) * math.cos(elevation)
57
- y = radius * math.sin(theta) * math.cos(elevation)
58
- z = radius * math.sin(elevation)
59
- camera_positions.append([x, y, z])
60
- camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
61
- extrinsics = center_looking_at_camera_pose(camera_positions)
62
- return extrinsics
63
-
64
- @staticmethod
65
- def _default_intrinsics():
66
- fx = fy = 384
67
- cx = cy = 256
68
- w = h = 512
69
- intrinsics = torch.tensor([
70
- [fx, fy],
71
- [cx, cy],
72
- [w, h],
73
- ], dtype=torch.float32)
74
- return intrinsics
75
-
76
- def _default_source_camera(self, batch_size: int = 1):
77
- dist_to_center = 1.5
78
- canonical_camera_extrinsics = torch.tensor([[
79
- [0, 0, 1, 1],
80
- [1, 0, 0, 0],
81
- [0, 1, 0, 0],
82
- ]], dtype=torch.float32)
83
- canonical_camera_intrinsics = self._default_intrinsics().unsqueeze(0)
84
- source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
85
- return source_camera.repeat(batch_size, 1)
86
-
87
- def _default_render_cameras(self, batch_size: int = 1):
88
- render_camera_extrinsics = self.get_surrounding_views(160, 1.5, 0)
89
- render_camera_intrinsics = self._default_intrinsics().unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
90
- render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
91
- return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
92
-
93
- @staticmethod
94
- def images_to_video(images, output_path, fps, verbose=False):
95
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
96
- frames = []
97
- for i in range(images.shape[0]):
98
- frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
99
- assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
100
- f"Frame shape mismatch: {frame.shape} vs {images.shape}"
101
- assert frame.min() >= 0 and frame.max() <= 255, \
102
- f"Frame value out of range: {frame.min()} ~ {frame.max()}"
103
- frames.append(frame)
104
- imageio.mimwrite(output_path, np.stack(frames), fps=fps)
105
- if verbose:
106
- print(f"Saved video to {output_path}")
107
-
108
- def infer_single(self, image: torch.Tensor, render_size: int, mesh_size: int, export_video: bool, export_mesh: bool):
109
- print("infer_single called")
110
- mesh_thres = 1.0
111
- chunk_size = 2
112
- batch_size = 1
113
-
114
- source_camera = self._default_source_camera(batch_size).to(self.device)
115
- render_cameras = self._default_render_cameras(batch_size).to(self.device)
116
-
117
- with torch.no_grad():
118
- planes = self.model.forward(image, source_camera)
119
- results = {}
120
-
121
- if export_video:
122
- print("Starting export_video")
123
- frames = []
124
- for i in range(0, render_cameras.shape[1], chunk_size):
125
- print(f"Processing chunk {i} to {i + chunk_size}")
126
- frames.append(
127
- self.model.synthesizer(
128
- planes,
129
- render_cameras[:, i:i+chunk_size],
130
- render_size,
131
- render_size,
132
- 0,
133
- 0
134
- )
135
- )
136
- frames = {
137
- k: torch.cat([r[k] for r in frames], dim=1)
138
- for k in frames[0].keys()
139
- }
140
- results.update({
141
- 'frames': frames,
142
- })
143
- print("Finished export_video")
144
-
145
- if export_mesh:
146
- print("Starting export_mesh")
147
- grid_out = self.model.synthesizer.forward_grid(
148
- planes=planes,
149
- grid_size=mesh_size,
150
- )
151
- vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
152
- vtx = vtx / (mesh_size - 1) * 2 - 1
153
- vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
154
- vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
155
- vtx_colors = (vtx_colors * 255).astype(np.uint8)
156
- mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
157
- results.update({
158
- 'mesh': mesh,
159
- })
160
- print("Finished export_mesh")
161
-
162
- return results
163
-
164
- def infer(self, source_image: str, dump_path: str, source_size: int, render_size: int, mesh_size: int, export_video: bool, export_mesh: bool):
165
- print("infer called")
166
- session = new_session("isnet-general-use")
167
- rembg_remove = partial(remove, session=session)
168
- image_name = os.path.basename(source_image)
169
- uid = image_name.split('.')[0]
170
-
171
- image = kiui.read_image(source_image, mode='uint8')
172
- image = rembg_remove(image)
173
- mask = rembg_remove(image, only_mask=True)
174
- image = recenter(image, mask, border_ratio=0.20)
175
- os.makedirs(dump_path, exist_ok=True)
176
-
177
- image = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.0
178
- if image.shape[1] == 4:
179
- image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
180
- image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
181
- image = torch.clamp(image, 0, 1)
182
- save_image(image, os.path.join(dump_path, f'{uid}.png'))
183
-
184
- results = self.infer_single(
185
- image.cuda(),
186
- render_size=render_size,
187
- mesh_size=mesh_size,
188
- export_video=export_video,
189
- export_mesh=export_mesh,
190
- )
191
-
192
- if 'frames' in results:
193
- renderings = results['frames']
194
- for k, v in renderings.items():
195
- if k == 'images_rgb':
196
- self.images_to_video(
197
- v[0],
198
- os.path.join(dump_path, f'{uid}.mp4'),
199
- fps=40,
200
- )
201
- print(f"Export video success to {dump_path}")
202
-
203
- if 'mesh' in results:
204
- mesh = results['mesh']
205
- mesh.export(os.path.join(dump_path, f'{uid}.obj'), 'obj')
206
-
207
- if __name__ == '__main__':
208
- parser = argparse.ArgumentParser()
209
- parser.add_argument('--model_name', type=str, default='lrm-base-obj-v1')
210
- parser.add_argument('--source_path', type=str, default='./assets/cat.png')
211
- parser.add_argument('--dump_path', type=str, default='./results/single_image')
212
- parser.add_argument('--source_size', type=int, default=512)
213
- parser.add_argument('--render_size', type=int, default=384)
214
- parser.add_argument('--mesh_size', type=int, default=512)
215
- parser.add_argument('--export_video', action='store_true')
216
- parser.add_argument('--export_mesh', action='store_true')
217
- parser.add_argument('--resume', type=str, required=True, help='Path to a checkpoint to resume training from')
218
- args = parser.parse_args()
219
-
220
- with LRMInferrer(model_name=args.model_name, resume=args.resume) as inferrer:
221
- with torch.autocast(device_type="cuda", cache_enabled=False, dtype=torch.float32):
222
- print("Start inference for image:", args.source_path)
223
- inferrer.infer(
224
- source_image=args.source_path,
225
- dump_path=args.dump_path,
226
- source_size=args.source_size,
227
- render_size=args.render_size,
228
- mesh_size=args.mesh_size,
229
- export_video=args.export_video,
230
- export_mesh=args.export_mesh,
231
- )
232
- print("Finished inference for image:", args.source_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
lrm/models/encoders/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
lrm/models/encoders/dino_wrapper2.py DELETED
@@ -1,51 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch.nn as nn
9
- from transformers import ViTImageProcessor, ViTModel, AutoImageProcessor, AutoModel, Dinov2Model
10
-
11
- class DinoWrapper(nn.Module):
12
- """
13
- Dino v1 wrapper using huggingface transformer implementation.
14
- """
15
- def __init__(self, model_name: str, freeze: bool = True):
16
- super().__init__()
17
- self.model, self.processor = self._build_dino(model_name)
18
- if freeze:
19
- self._freeze()
20
-
21
- def forward(self, image):
22
- # image: [N, C, H, W], on cpu
23
- # RGB image with [0,1] scale and properly sized
24
- inputs = self.processor(images=image.float(), return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
25
- # This resampling of positional embedding uses bicubic interpolation
26
- outputs = self.model(**inputs)
27
- last_hidden_states = outputs.last_hidden_state
28
- return last_hidden_states
29
-
30
- def _freeze(self):
31
- print(f"======== Freezing DinoWrapper ========")
32
- self.model.eval()
33
- for name, param in self.model.named_parameters():
34
- param.requires_grad = False
35
-
36
- @staticmethod
37
- def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
38
- import requests
39
- try:
40
- processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
41
- processor.do_center_crop = False
42
- model = AutoModel.from_pretrained('facebook/dinov2-base')
43
- return model, processor
44
- except requests.exceptions.ProxyError as err:
45
- if proxy_error_retries > 0:
46
- print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
47
- import time
48
- time.sleep(proxy_error_cooldown)
49
- return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
50
- else:
51
- raise err
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/generator.py DELETED
@@ -1,87 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch.nn as nn
9
-
10
- from .encoders.dino_wrapper2 import DinoWrapper
11
- from .transformer import TriplaneTransformer
12
- from .rendering.synthesizer_part import TriplaneSynthesizer
13
-
14
-
15
- class CameraEmbedder(nn.Module):
16
- """
17
- Embed camera features to a high-dimensional vector.
18
-
19
- Reference:
20
- DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
21
- """
22
- def __init__(self, raw_dim: int, embed_dim: int):
23
- super().__init__()
24
- self.mlp = nn.Sequential(
25
- nn.Linear(raw_dim, embed_dim),
26
- nn.SiLU(),
27
- nn.Linear(embed_dim, embed_dim),
28
- )
29
-
30
- def forward(self, x):
31
- return self.mlp(x)
32
-
33
-
34
- class LRMGenerator(nn.Module):
35
- """
36
- Full model of the large reconstruction model.
37
- """
38
- def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int,
39
- transformer_dim: int, transformer_layers: int, transformer_heads: int,
40
- triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
41
- encoder_freeze: bool = True, encoder_model_name: str = 'facebook/dinov2-base', encoder_feat_dim: int = 768):
42
- super().__init__()
43
-
44
- # attributes
45
- self.encoder_feat_dim = encoder_feat_dim
46
- self.camera_embed_dim = camera_embed_dim
47
-
48
- # modules
49
- self.encoder = DinoWrapper(
50
- model_name=encoder_model_name,
51
- freeze=encoder_freeze,
52
- )
53
- self.camera_embedder = CameraEmbedder(
54
- raw_dim=12+4, embed_dim=camera_embed_dim,
55
- )
56
- self.transformer = TriplaneTransformer(
57
- inner_dim=transformer_dim, num_layers=transformer_layers, num_heads=transformer_heads,
58
- image_feat_dim=encoder_feat_dim,
59
- camera_embed_dim=camera_embed_dim,
60
- triplane_low_res=triplane_low_res, triplane_high_res=triplane_high_res, triplane_dim=triplane_dim,
61
- )
62
- self.synthesizer = TriplaneSynthesizer(
63
- triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray,
64
- )
65
-
66
- def forward(self, image, camera):
67
- # image: [N, C_img, H_img, W_img]
68
- # camera: [N, D_cam_raw]
69
- assert image.shape[0] == camera.shape[0], "Batch size mismatch for image and camera"
70
- N = image.shape[0]
71
-
72
- # encode image
73
- image_feats = self.encoder(image)
74
- assert image_feats.shape[-1] == self.encoder_feat_dim, \
75
- f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
76
-
77
- # embed camera
78
- camera_embeddings = self.camera_embedder(camera)
79
- assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
80
- f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
81
-
82
- # transformer generating planes
83
- planes = self.transformer(image_feats, camera_embeddings)
84
- assert planes.shape[0] == N, "Batch size mismatch for planes"
85
- assert planes.shape[1] == 3, "Planes should have 3 channels"
86
- return planes
87
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/rendering/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
lrm/models/rendering/synthesizer_part.py DELETED
@@ -1,194 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import itertools
9
- import torch
10
- import torch.nn as nn
11
-
12
- from .utils.renderer import ImportanceRenderer
13
- from .utils.ray_sampler_part import RaySampler
14
-
15
-
16
- class OSGDecoder(nn.Module):
17
- """
18
- Triplane decoder that gives RGB and sigma values from sampled features.
19
- Using ReLU here instead of Softplus in the original implementation.
20
-
21
- Reference:
22
- EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
23
- """
24
- def __init__(self, n_features: int,
25
- hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
26
- super().__init__()
27
- self.net = nn.Sequential(
28
- nn.Linear(3 * n_features, hidden_dim),
29
- activation(),
30
- *itertools.chain(*[[
31
- nn.Linear(hidden_dim, hidden_dim),
32
- activation(),
33
- ] for _ in range(num_layers - 2)]),
34
- nn.Linear(hidden_dim, 1 + 3),
35
- )
36
- # init all bias to zero
37
- for m in self.modules():
38
- if isinstance(m, nn.Linear):
39
- nn.init.zeros_(m.bias)
40
-
41
- def forward(self, sampled_features, ray_directions):
42
- # Aggregate features by mean
43
- # sampled_features = sampled_features.mean(1)
44
- # Aggregate features by concatenation
45
- _N, n_planes, _M, _C = sampled_features.shape
46
- sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
47
- x = sampled_features
48
-
49
- N, M, C = x.shape
50
- x = x.contiguous().view(N*M, C)
51
-
52
- x = self.net(x)
53
- x = x.view(N, M, -1)
54
- rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
55
- sigma = x[..., 0:1]
56
-
57
- return {'rgb': rgb, 'sigma': sigma}
58
-
59
-
60
- class TriplaneSynthesizer(nn.Module):
61
- """
62
- Synthesizer that renders a triplane volume with planes and a camera.
63
-
64
- Reference:
65
- EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
66
- """
67
-
68
- DEFAULT_RENDERING_KWARGS = {
69
- 'ray_start': 'auto',
70
- 'ray_end': 'auto',
71
- 'box_warp': 2.,
72
- 'white_back': True,
73
- 'disparity_space_sampling': False,
74
- 'clamp_mode': 'softplus',
75
- 'sampler_bbox_min': -1.,
76
- 'sampler_bbox_max': 1.,
77
- }
78
-
79
- def __init__(self, triplane_dim: int, samples_per_ray: int):
80
- super().__init__()
81
-
82
- # attributes
83
- self.triplane_dim = triplane_dim
84
- self.rendering_kwargs = {
85
- **self.DEFAULT_RENDERING_KWARGS,
86
- 'depth_resolution': samples_per_ray // 2,
87
- 'depth_resolution_importance': samples_per_ray // 2,
88
- }
89
-
90
- # renderings
91
- self.renderer = ImportanceRenderer()
92
- self.ray_sampler = RaySampler()
93
-
94
- # modules
95
- self.decoder = OSGDecoder(n_features=triplane_dim)
96
-
97
- def forward(self, planes, cameras, render_size: int, crop_size: int, start_x: int, start_y:int):
98
- # planes: (N, 3, D', H', W')
99
- # cameras: (N, M, D_cam)
100
- # render_size: int
101
- assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
102
- N, M = cameras.shape[:2]
103
- cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
104
- intrinsics = cameras[..., 16:25].view(N, M, 3, 3)
105
-
106
- # Create a batch of rays for volume rendering
107
- ray_origins, ray_directions = self.ray_sampler(
108
- cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
109
- intrinsics=intrinsics.reshape(-1, 3, 3),
110
- render_size=render_size,
111
- crop_size = crop_size,
112
- start_x = start_x,
113
- start_y = start_y
114
- )
115
- assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
116
- assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
117
- # Perform volume rendering
118
- rgb_samples, depth_samples, weights_samples = self.renderer(
119
- planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
120
- )
121
-
122
- # Reshape into 'raw' neural-rendered image
123
- Himg = Wimg = crop_size
124
- rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
125
- depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
126
- weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
127
-
128
- return {
129
- 'images_rgb': rgb_images,
130
- 'images_depth': depth_images,
131
- 'images_weight': weight_images,
132
- }
133
-
134
- def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
135
- # planes: (N, 3, D', H', W')
136
- # grid_size: int
137
- # aabb: (N, 2, 3)
138
- if aabb is None:
139
- aabb = torch.tensor([
140
- [self.rendering_kwargs['sampler_bbox_min']] * 3,
141
- [self.rendering_kwargs['sampler_bbox_max']] * 3,
142
- ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
143
- assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
144
- N = planes.shape[0]
145
-
146
- # create grid points for triplane query
147
- grid_points = []
148
- for i in range(N):
149
- grid_points.append(torch.stack(torch.meshgrid(
150
- torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
151
- torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
152
- torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
153
- indexing='ij',
154
- ), dim=-1).reshape(-1, 3))
155
- cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
156
-
157
- features = self.forward_points(planes, cube_grid)
158
-
159
- # reshape into grid
160
- features = {
161
- k: v.reshape(N, grid_size, grid_size, grid_size, -1)
162
- for k, v in features.items()
163
- }
164
- return features
165
-
166
- def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
167
- # planes: (N, 3, D', H', W')
168
- # points: (N, P, 3)
169
- N, P = points.shape[:2]
170
-
171
- # query triplane in chunks
172
- outs = []
173
- for i in range(0, points.shape[1], chunk_size):
174
- chunk_points = points[:, i:i+chunk_size]
175
-
176
- # query triplane
177
- chunk_out = self.renderer.run_model_activated(
178
- planes=planes,
179
- decoder=self.decoder,
180
- sample_coordinates=chunk_points,
181
- sample_directions=torch.zeros_like(chunk_points),
182
- options=self.rendering_kwargs,
183
- )
184
- outs.append(chunk_out)
185
-
186
- # concatenate the outputs
187
- point_features = {
188
- k: torch.cat([out[k] for out in outs], dim=1)
189
- for k in outs[0].keys()
190
- }
191
-
192
- sig = point_features['sigma']
193
- print(sig.mean(), sig.max(), sig.min())
194
- return point_features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/rendering/utils/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
7
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
8
- #
9
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
10
- # property and proprietary rights in and to this material, related
11
- # documentation and any modifications thereto. Any use, reproduction,
12
- # disclosure or distribution of this material and related documentation
13
- # without an express license agreement from NVIDIA CORPORATION or
14
- # its affiliates is strictly prohibited.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/rendering/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (192 Bytes)
 
lrm/models/rendering/utils/__pycache__/math_utils.cpython-310.pyc DELETED
Binary file (2.76 kB)
 
lrm/models/rendering/utils/__pycache__/ray_marcher.cpython-310.pyc DELETED
Binary file (2.01 kB)
 
lrm/models/rendering/utils/__pycache__/ray_sampler_part.cpython-310.pyc DELETED
Binary file (2.75 kB)
 
lrm/models/rendering/utils/__pycache__/renderer.cpython-310.pyc DELETED
Binary file (10.5 kB)
 
lrm/models/rendering/utils/math_utils.py DELETED
@@ -1,123 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # MIT License
7
-
8
- # Copyright (c) 2022 Petr Kellnhofer
9
-
10
- # Permission is hereby granted, free of charge, to any person obtaining a copy
11
- # of this software and associated documentation files (the "Software"), to deal
12
- # in the Software without restriction, including without limitation the rights
13
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
- # copies of the Software, and to permit persons to whom the Software is
15
- # furnished to do so, subject to the following conditions:
16
-
17
- # The above copyright notice and this permission notice shall be included in all
18
- # copies or substantial portions of the Software.
19
-
20
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
- # SOFTWARE.
27
-
28
- import torch
29
-
30
- def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
31
- """
32
- Left-multiplies MxM @ NxM. Returns NxM.
33
- """
34
- res = torch.matmul(vectors4, matrix.T)
35
- return res
36
-
37
-
38
- def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
39
- """
40
- Normalize vector lengths.
41
- """
42
- return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
43
-
44
- def torch_dot(x: torch.Tensor, y: torch.Tensor):
45
- """
46
- Dot product of two tensors.
47
- """
48
- return (x * y).sum(-1)
49
-
50
-
51
- def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
52
- """
53
- Author: Petr Kellnhofer
54
- Intersects rays with the [-1, 1] NDC volume.
55
- Returns min and max distance of entry.
56
- Returns -1 for no intersection.
57
- https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
58
- """
59
- o_shape = rays_o.shape
60
- rays_o = rays_o.detach().reshape(-1, 3)
61
- rays_d = rays_d.detach().reshape(-1, 3)
62
-
63
-
64
- bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
65
- bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
66
- bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
67
- is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
68
-
69
- # Precompute inverse for stability.
70
- invdir = 1 / rays_d
71
- sign = (invdir < 0).long()
72
-
73
- # Intersect with YZ plane.
74
- tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
75
- tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
76
-
77
- # Intersect with XZ plane.
78
- tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
79
- tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
80
-
81
- # Resolve parallel rays.
82
- is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
83
-
84
- # Use the shortest intersection.
85
- tmin = torch.max(tmin, tymin)
86
- tmax = torch.min(tmax, tymax)
87
-
88
- # Intersect with XY plane.
89
- tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
90
- tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
91
-
92
- # Resolve parallel rays.
93
- is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
94
-
95
- # Use the shortest intersection.
96
- tmin = torch.max(tmin, tzmin)
97
- tmax = torch.min(tmax, tzmax)
98
-
99
- # Mark invalid.
100
- tmin[torch.logical_not(is_valid)] = -1
101
- tmax[torch.logical_not(is_valid)] = -2
102
-
103
- return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
104
-
105
-
106
- def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
107
- """
108
- Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
109
- Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
110
- """
111
- # create a tensor of 'num' steps from 0 to 1
112
- steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
113
-
114
- # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
115
- # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
116
- # "cannot statically infer the expected size of a list in this contex", hence the code below
117
- for i in range(start.ndim):
118
- steps = steps.unsqueeze(-1)
119
-
120
- # the output starts at 'start' and increments until 'stop' in each dimension
121
- out = start[None] + steps * (stop - start)[None]
122
-
123
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/rendering/utils/ray_marcher.py DELETED
@@ -1,73 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
7
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
8
- #
9
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
10
- # property and proprietary rights in and to this material, related
11
- # documentation and any modifications thereto. Any use, reproduction,
12
- # disclosure or distribution of this material and related documentation
13
- # without an express license agreement from NVIDIA CORPORATION or
14
- # its affiliates is strictly prohibited.
15
- #
16
- # Modified by Zexin He
17
- # The modifications are subject to the same license as the original.
18
-
19
-
20
- """
21
- The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
22
- Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
23
- """
24
-
25
- import torch
26
- import torch.nn as nn
27
-
28
-
29
- class MipRayMarcher2(nn.Module):
30
- def __init__(self, activation_factory):
31
- super().__init__()
32
- self.activation_factory = activation_factory
33
-
34
- def run_forward(self, colors, densities, depths, rendering_options):
35
-
36
- deltas = depths[:, :, 1:] - depths[:, :, :-1]
37
- colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
38
- densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
39
- depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
40
-
41
-
42
-
43
- # using factory mode for better usability
44
- densities_mid = self.activation_factory(rendering_options)(densities_mid)
45
-
46
- density_delta = densities_mid * deltas
47
-
48
- alpha = 1 - torch.exp(-density_delta)
49
-
50
- alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
51
- weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
52
-
53
- composite_rgb = torch.sum(weights * colors_mid, -2)
54
- weight_total = weights.sum(2)
55
- composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
56
-
57
- # clip the composite to min/max range of depths
58
- composite_depth = torch.nan_to_num(composite_depth, float('inf'))
59
- composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
60
-
61
- if rendering_options.get('white_back', False):
62
- composite_rgb = composite_rgb + 1 - weight_total
63
-
64
- # rendered value scale is 0-1, comment out original mipnerf scaling
65
- # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
66
-
67
- return composite_rgb, composite_depth, weights
68
-
69
-
70
- def forward(self, colors, densities, depths, rendering_options):
71
- composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
72
-
73
- return composite_rgb, composite_depth, weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/rendering/utils/ray_sampler_part.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
7
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
8
- #
9
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
10
- # property and proprietary rights in and to this material, related
11
- # documentation and any modifications thereto. Any use, reproduction,
12
- # disclosure or distribution of this material and related documentation
13
- # without an express license agreement from NVIDIA CORPORATION or
14
- # its affiliates is strictly prohibited.
15
- #
16
- # Modified by Zexin He
17
- # The modifications are subject to the same license as the original.
18
-
19
-
20
- """
21
- The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
22
- Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
23
- """
24
-
25
- import torch
26
-
27
- class RaySampler(torch.nn.Module):
28
- def __init__(self):
29
- super().__init__()
30
- self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
31
-
32
-
33
- def forward(self, cam2world_matrix, intrinsics, render_size, crop_size, start_x, start_y):
34
- """
35
- Create batches of rays and return origins and directions.
36
-
37
- cam2world_matrix: (N, 4, 4)
38
- intrinsics: (N, 3, 3)
39
- render_size: int
40
-
41
- ray_origins: (N, M, 3)
42
- ray_dirs: (N, M, 2)
43
- """
44
-
45
- N, M = cam2world_matrix.shape[0], crop_size**2
46
- cam_locs_world = cam2world_matrix[:, :3, 3]
47
- fx = intrinsics[:, 0, 0]
48
- fy = intrinsics[:, 1, 1]
49
- cx = intrinsics[:, 0, 2]
50
- cy = intrinsics[:, 1, 2]
51
- sk = intrinsics[:, 0, 1]
52
-
53
- uv = torch.stack(torch.meshgrid(
54
- torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
55
- torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
56
- indexing='ij',
57
- ))
58
- if crop_size < render_size:
59
- patch_uv = []
60
- for i in range(cam2world_matrix.shape[0]):
61
- patch_uv.append(uv.clone()[None, :, start_y:start_y+crop_size, start_x:start_x+crop_size])
62
- uv = torch.cat(patch_uv, 0)
63
- uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
64
- else:
65
- uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
66
- uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
67
- # uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
68
- # uv = uv.flip(1).reshape(cam2world_matrix.shape[0], 2, -1).transpose(2, 1)
69
- x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
70
- y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
71
- z_cam = torch.ones((N, M), device=cam2world_matrix.device)
72
-
73
- x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
74
- y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
75
-
76
- cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).float()
77
-
78
- _opencv2blender = torch.tensor([
79
- [1, 0, 0, 0],
80
- [0, -1, 0, 0],
81
- [0, 0, -1, 0],
82
- [0, 0, 0, 1],
83
- ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
84
-
85
- # added float here
86
- cam2world_matrix = torch.bmm(cam2world_matrix.float(), _opencv2blender.float())
87
-
88
- world_rel_points = torch.bmm(cam2world_matrix.float(), cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
89
-
90
- ray_dirs = world_rel_points - cam_locs_world[:, None, :]
91
- ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
92
-
93
- ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
94
- return ray_origins, ray_dirs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lrm/models/rendering/utils/renderer.py DELETED
@@ -1,314 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
7
- # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
8
- #
9
- # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
10
- # property and proprietary rights in and to this material, related
11
- # documentation and any modifications thereto. Any use, reproduction,
12
- # disclosure or distribution of this material and related documentation
13
- # without an express license agreement from NVIDIA CORPORATION or
14
- # its affiliates is strictly prohibited.
15
- #
16
- # Modified by Zexin He
17
- # The modifications are subject to the same license as the original.
18
-
19
-
20
- """
21
- The renderer is a module that takes in rays, decides where to sample along each
22
- ray, and computes pixel colors using the volume rendering equation.
23
- """
24
-
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
-
29
- from .ray_marcher import MipRayMarcher2
30
- from . import math_utils
31
-
32
- def generate_planes():
33
- """
34
- Defines planes by the three vectors that form the "axes" of the
35
- plane. Should work with arbitrary number of planes and planes of
36
- arbitrary orientation.
37
-
38
- Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
39
- """
40
- return torch.tensor([[[1, 0, 0],
41
- [0, 1, 0],
42
- [0, 0, 1]],
43
- [[1, 0, 0],
44
- [0, 0, 1],
45
- [0, 1, 0]],
46
- [[0, 0, 1],
47
- [0, 1, 0],
48
- [1, 0, 0]]], dtype=torch.float32)
49
-
50
- def project_onto_planes(planes, coordinates):
51
- """
52
- Does a projection of a 3D point onto a batch of 2D planes,
53
- returning 2D plane coordinates.
54
-
55
- Takes plane axes of shape n_planes, 3, 3
56
- # Takes coordinates of shape N, M, 3
57
- # returns projections of shape N*n_planes, M, 2
58
- """
59
- N, M, C = coordinates.shape
60
- n_planes, _, _ = planes.shape
61
- coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
62
- inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
63
- coordinates = coordinates.to(inv_planes.device)
64
- projections = torch.bmm(coordinates, inv_planes)
65
- return projections[..., :2]
66
-
67
- def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
68
- assert padding_mode == 'zeros'
69
- N, n_planes, C, H, W = plane_features.shape
70
- _, M, _ = coordinates.shape
71
- plane_features = plane_features.view(N*n_planes, C, H, W)
72
-
73
- coordinates = (2/box_warp) * coordinates # add specific box bounds
74
- # half added here
75
- projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
76
- # removed float from projected_coordinates
77
- output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
78
- return output_features
79
-
80
- def sample_from_3dgrid(grid, coordinates):
81
- """
82
- Expects coordinates in shape (batch_size, num_points_per_batch, 3)
83
- Expects grid in shape (1, channels, H, W, D)
84
- (Also works if grid has batch size)
85
- Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
86
- """
87
- batch_size, n_coords, n_dims = coordinates.shape
88
- sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
89
- coordinates.reshape(batch_size, 1, 1, -1, n_dims),
90
- mode='bilinear', padding_mode='zeros', align_corners=False)
91
- N, C, H, W, D = sampled_features.shape
92
- sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
93
- return sampled_features
94
-
95
- class ImportanceRenderer(torch.nn.Module):
96
- """
97
- Modified original version to filter out-of-box samples as TensoRF does.
98
-
99
- Reference:
100
- TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
101
- """
102
- def __init__(self):
103
- super().__init__()
104
- self.activation_factory = self._build_activation_factory()
105
- self.ray_marcher = MipRayMarcher2(self.activation_factory)
106
- self.plane_axes = generate_planes()
107
-
108
- def _build_activation_factory(self):
109
- def activation_factory(options: dict):
110
- if options['clamp_mode'] == 'softplus':
111
- return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
112
- else:
113
- assert False, "Renderer only supports `clamp_mode`=`softplus`!"
114
- return activation_factory
115
-
116
- def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
117
- planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
118
- """
119
- Additional filtering is applied to filter out-of-box samples.
120
- Modifications made by Zexin He.
121
- """
122
-
123
- # context related variables
124
- batch_size, num_rays, samples_per_ray, _ = depths.shape
125
- device = planes.device
126
- depths = depths.to(device)
127
- ray_directions = ray_directions.to(device)
128
- ray_origins = ray_origins.to(device)
129
- # define sample points with depths
130
- sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
131
- sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
132
-
133
- # filter out-of-box samples
134
- mask_inbox = \
135
- (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
136
- (sample_coordinates <= rendering_options['sampler_bbox_max'])
137
- mask_inbox = mask_inbox.all(-1)
138
-
139
- # forward model according to all samples
140
- _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
141
-
142
- # set out-of-box samples to zeros(rgb) & -inf(sigma)
143
- SAFE_GUARD = 3
144
- DATA_TYPE = _out['sigma'].dtype
145
- colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
146
- densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
147
- colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
148
-
149
- # reshape back
150
- colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
151
- densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
152
-
153
- return colors_pass, densities_pass
154
-
155
- def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
156
- # self.plane_axes = self.plane_axes.to(ray_origins.device)
157
-
158
- if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
159
- ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
160
- is_ray_valid = ray_end > ray_start
161
- if torch.any(is_ray_valid).item():
162
- ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
163
- ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
164
- depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
165
- else:
166
- # Create stratified depth samples
167
- depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
168
-
169
- depths_coarse = depths_coarse.to(planes.device)
170
-
171
- # Coarse Pass
172
- colors_coarse, densities_coarse = self._forward_pass(
173
- depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
174
- planes=planes, decoder=decoder, rendering_options=rendering_options)
175
-
176
- # Fine Pass
177
- N_importance = rendering_options['depth_resolution_importance']
178
- if N_importance > 0:
179
- _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
180
-
181
- depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
182
-
183
- colors_fine, densities_fine = self._forward_pass(
184
- depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
185
- planes=planes, decoder=decoder, rendering_options=rendering_options)
186
-
187
- all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
188
- depths_fine, colors_fine, densities_fine)
189
-
190
- # Aggregate
191
- rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
192
- else:
193
- rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
194
-
195
- return rgb_final, depth_final, weights.sum(2)
196
-
197
- def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
198
- plane_axes = self.plane_axes.to(planes.device)
199
- sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
200
-
201
- out = decoder(sampled_features, sample_directions)
202
- if options.get('density_noise', 0) > 0:
203
- out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
204
- return out
205
-
206
- def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
207
- out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
208
- out['sigma'] = self.activation_factory(options)(out['sigma'])
209
- return out
210
-
211
- def sort_samples(self, all_depths, all_colors, all_densities):
212
- _, indices = torch.sort(all_depths, dim=-2)
213
- all_depths = torch.gather(all_depths, -2, indices)
214
- all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
215
- all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
216
- return all_depths, all_colors, all_densities
217
-
218
- def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
219
- all_depths = torch.cat([depths1, depths2], dim = -2)
220
- all_colors = torch.cat([colors1, colors2], dim = -2)
221
- all_densities = torch.cat([densities1, densities2], dim = -2)
222
-
223
- _, indices = torch.sort(all_depths, dim=-2)
224
- all_depths = torch.gather(all_depths, -2, indices)
225
- all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
226
- all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
227
-
228
- return all_depths, all_colors, all_densities
229
-
230
- def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
231
- """
232
- Return depths of approximately uniformly spaced samples along rays.
233
- """
234
- N, M, _ = ray_origins.shape
235
- if disparity_space_sampling:
236
- depths_coarse = torch.linspace(0,
237
- 1,
238
- depth_resolution,
239
- device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
240
- depth_delta = 1/(depth_resolution - 1)
241
- depths_coarse += torch.rand_like(depths_coarse) * depth_delta
242
- depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
243
- else:
244
- if type(ray_start) == torch.Tensor:
245
- depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
246
- depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
247
- depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
248
- else:
249
- depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
250
- depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
251
- depths_coarse += torch.rand_like(depths_coarse) * depth_delta
252
-
253
- return depths_coarse
254
-
255
- def sample_importance(self, z_vals, weights, N_importance):
256
- """
257
- Return depths of importance sampled points along rays. See NeRF importance sampling for more.
258
- """
259
- with torch.no_grad():
260
- batch_size, num_rays, samples_per_ray, _ = z_vals.shape
261
-
262
- z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
263
- weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
264
-
265
- # smooth weights
266
- weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
267
- weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
268
- weights = weights + 0.01
269
-
270
- z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
271
- importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
272
- N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
273
- return importance_z_vals
274
-
275
- def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
276
- """
277
- Sample @N_importance samples from @bins with distribution defined by @weights.
278
- Inputs:
279
- bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
280
- weights: (N_rays, N_samples_)
281
- N_importance: the number of samples to draw from the distribution
282
- det: deterministic or not
283
- eps: a small number to prevent division by zero
284
- Outputs:
285
- samples: the importance-sampled depth values, shape (N_rays, N_importance)
286
- """
287
- N_rays, N_samples_ = weights.shape
288
- weights = weights + eps # prevent division by zero (don't do inplace op!)
289
- pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
290
- cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
291
- cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
292
- # padded to 0~1 inclusive
293
-
294
- if det:
295
- u = torch.linspace(0, 1, N_importance, device=bins.device)
296
- u = u.expand(N_rays, N_importance)
297
- else:
298
- u = torch.rand(N_rays, N_importance, device=bins.device)
299
- u = u.contiguous()
300
-
301
- inds = torch.searchsorted(cdf, u, right=True)
302
- below = torch.clamp_min(inds-1, 0)
303
- above = torch.clamp_max(inds, N_samples_)
304
-
305
- inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
306
- cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
307
- bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
308
-
309
- denom = cdf_g[...,1]-cdf_g[...,0]
310
- denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
311
- # anyway, therefore any value for it is fine (set to 1 here)
312
-
313
- samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
314
- return samples
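
sample_pdf is standard NeRF inverse-transform sampling: build a CDF from the (smoothed) weights, draw uniform variates, and map them back through the CDF to depths, so new samples concentrate where the coarse pass found density. A self-contained toy example on one ray (all numbers illustrative):

import torch

bins = torch.linspace(0.0, 1.0, 9).unsqueeze(0)                   # (1, N_samples+1) bin edges
weights = torch.tensor([[0., 0., 1., 4., 4., 1., 0., 0.]])        # (1, N_samples), mass in the middle

pdf = (weights + 1e-5) / (weights + 1e-5).sum(-1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)          # (1, N_samples+1), runs 0 -> 1

u = torch.linspace(0, 1, 16).expand(1, 16).contiguous()           # deterministic draw (det=True)
inds = torch.searchsorted(cdf, u, right=True)
below = torch.clamp_min(inds - 1, 0)
above = torch.clamp_max(inds, weights.shape[-1])

inds_g = torch.stack([below, above], -1).view(1, -1)
cdf_g = torch.gather(cdf, 1, inds_g).view(1, 16, 2)
bins_g = torch.gather(bins, 1, inds_g).view(1, 16, 2)
denom = (cdf_g[..., 1] - cdf_g[..., 0]).clamp_min(1e-5)
samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
# most of the 16 new depths land between 0.25 and 0.75, where the weights are large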
 
 
 
 
 
 
lrm/models/transformer.py DELETED
@@ -1,135 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch
9
- import torch.nn as nn
10
-
11
-
12
- class ModLN(nn.Module):
13
- """
14
- Modulation with adaLN.
15
-
16
- References:
17
- DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
18
- """
19
- def __init__(self, inner_dim: int, mod_dim: int, eps: float):
20
- super().__init__()
21
- self.norm = nn.LayerNorm(inner_dim, eps=eps)
22
- self.mlp = nn.Sequential(
23
- nn.SiLU(),
24
- nn.Linear(mod_dim, inner_dim * 2),
25
- )
26
-
27
- @staticmethod
28
- def modulate(x, shift, scale):
29
- # x: [N, L, D]
30
- # shift, scale: [N, D]
31
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
32
-
33
- def forward(self, x, cond):
34
- shift, scale = self.mlp(cond).chunk(2, dim=-1) # [N, D]
35
- return self.modulate(self.norm(x), shift, scale) # [N, L, D]
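
In other words, the conditioning vector is mapped to a per-sample (shift, scale) pair that modulates the normalized tokens, as in DiT's adaLN. A small usage sketch, assuming the ModLN class above and illustrative sizes:

import torch

mod_ln = ModLN(inner_dim=512, mod_dim=256, eps=1e-6)
x = torch.randn(2, 10, 512)       # [N, L, D] token sequence
cond = torch.randn(2, 256)        # [N, D_mod] conditioning vector (e.g. camera embedding)
y = mod_ln(x, cond)               # [2, 10, 512], normalized then shifted/scaled per sample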
36
-
37
-
38
- class ConditionModulationBlock(nn.Module):
39
- """
40
- Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
41
- """
42
- # uses attention from torch.nn.MultiheadAttention
43
- # Block contains a cross-attention layer, a self-attention layer, and a MLP
44
- def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
45
- attn_drop: float = 0., attn_bias: bool = False,
46
- mlp_ratio: float = 4., mlp_drop: float = 0.):
47
- super().__init__()
48
- self.norm1 = ModLN(inner_dim, mod_dim, eps)
49
- self.cross_attn = nn.MultiheadAttention(
50
- embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
51
- dropout=attn_drop, bias=attn_bias, batch_first=True)
52
- self.norm2 = ModLN(inner_dim, mod_dim, eps)
53
- self.self_attn = nn.MultiheadAttention(
54
- embed_dim=inner_dim, num_heads=num_heads,
55
- dropout=attn_drop, bias=attn_bias, batch_first=True)
56
- self.norm3 = ModLN(inner_dim, mod_dim, eps)
57
- self.mlp = nn.Sequential(
58
- nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
59
- nn.GELU(),
60
- nn.Dropout(mlp_drop),
61
- nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
62
- nn.Dropout(mlp_drop),
63
- )
64
-
65
- def forward(self, x, cond, mod):
66
- # x: [N, L, D]
67
- # cond: [N, L_cond, D_cond]
68
- # mod: [N, D_mod]
69
- x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
70
- before_sa = self.norm2(x, mod)
71
- x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
72
- x = x + self.mlp(self.norm3(x, mod))
73
- return x
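
Each block therefore applies, in order, modulated cross-attention to the image features, modulated self-attention over the triplane tokens, and a modulated MLP, all with residual connections. A usage sketch assuming the class above; the dimensions are illustrative, not the model's actual configuration:

import torch

block = ConditionModulationBlock(inner_dim=512, cond_dim=768, mod_dim=256,
                                 num_heads=8, eps=1e-6)
x = torch.randn(2, 3 * 32 * 32, 512)   # [N, L, D] triplane tokens
cond = torch.randn(2, 197, 768)        # [N, L_cond, D_cond] image features
mod = torch.randn(2, 256)              # [N, D_mod] camera embedding
y = block(x, cond, mod)                # [2, 3072, 512]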
74
-
75
-
76
- class TriplaneTransformer(nn.Module):
77
- """
78
- Transformer with condition and modulation that generates a triplane representation.
79
-
80
- Reference:
81
- Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
82
- """
83
- def __init__(self, inner_dim: int, image_feat_dim: int, camera_embed_dim: int,
84
- triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
85
- num_layers: int, num_heads: int,
86
- eps: float = 1e-6):
87
- super().__init__()
88
-
89
- # attributes
90
- self.triplane_low_res = triplane_low_res
91
- self.triplane_high_res = triplane_high_res
92
- self.triplane_dim = triplane_dim
93
-
94
- # modules
95
- # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
96
- self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
97
- self.layers = nn.ModuleList([
98
- ConditionModulationBlock(
99
- inner_dim=inner_dim, cond_dim=image_feat_dim, mod_dim=camera_embed_dim, num_heads=num_heads, eps=eps)
100
- for _ in range(num_layers)
101
- ])
102
- self.norm = nn.LayerNorm(inner_dim, eps=eps)
103
- self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
104
-
105
- def forward(self, image_feats, camera_embeddings):
106
- # image_feats: [N, L_cond, D_cond]
107
- # camera_embeddings: [N, D_mod]
108
-
109
- assert image_feats.shape[0] == camera_embeddings.shape[0], \
110
- f"Mismatched batch size: {image_feats.shape[0]} vs {camera_embeddings.shape[0]}"
111
-
112
- N = image_feats.shape[0]
113
- H = W = self.triplane_low_res
114
- L = 3 * H * W
115
-
116
- x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
117
- for layer in self.layers:
118
- x = layer(x, image_feats, camera_embeddings)
119
- x = self.norm(x)
120
-
121
- # separate each plane and apply deconv
122
- x = x.view(N, 3, H, W, -1)
123
- x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
124
- x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
125
- x = self.deconv(x) # [3*N, D', H', W']
126
- x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
127
- x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
128
- x = x.contiguous()
129
-
130
- assert self.triplane_high_res == x.shape[-2], \
131
- f"Output triplane resolution does not match with expected: {x.shape[-2]} vs {self.triplane_high_res}"
132
- assert self.triplane_dim == x.shape[-3], \
133
- f"Output triplane dimension does not match with expected: {x.shape[-3]} vs {self.triplane_dim}"
134
-
135
- return x
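
End to end, the module turns learned positional tokens, conditioned on image features and a camera embedding, into an upsampled triplane of shape [N, 3, triplane_dim, H', W']. A usage sketch with made-up hyperparameters (the released model's configuration may differ):

import torch

model = TriplaneTransformer(
    inner_dim=512, image_feat_dim=768, camera_embed_dim=256,
    triplane_low_res=32, triplane_high_res=64, triplane_dim=40,
    num_layers=4, num_heads=8,
)
image_feats = torch.randn(2, 197, 768)     # e.g. ViT patch tokens
camera_emb = torch.randn(2, 256)
planes = model(image_feats, camera_emb)    # [2, 3, 40, 64, 64]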