#positions = scale_tensor( # positions, (-self.cfg.radius, self.cfg.radius), (-1, 1) #)
tsr/models/nerf_renderer.py
CHANGED (+56 -4)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 import torch.nn.functional as F
@@ -37,7 +37,59 @@ class TriplaneNeRFRenderer(BaseModule):
             chunk_size >= 0
         ), "chunk_size must be a non-negative integer (0 for no chunking)."
         self.chunk_size = chunk_size
+    def make_step_grid(self, device, resolution: int, chunk_size: int = 32):
+        coords = torch.linspace(-1.0, 1.0, resolution, device=device)
+        x, y, z = torch.meshgrid(coords[0:chunk_size], coords, coords, indexing="ij")
+        x = x.reshape(-1, 1)
+        y = y.reshape(-1, 1)
+        z = z.reshape(-1, 1)
+        verts = torch.cat([x, y, z], dim=-1).view(-1, 3)
+        indices2D: torch.Tensor = torch.stack(
+            (verts[..., [0, 1]], verts[..., [0, 2]], verts[..., [1, 2]]),
+            dim=-3,
+        )
+        return indices2D
+
+    def query_triplane_volume_density(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, sample_count: int = 1024 * 1024 * 4) -> torch.Tensor:
+        layer_count = sample_count // (resolution * resolution)
+        out_list = self.do_query_triplane_volume_density(decoder, triplane, resolution, layer_count)
+        return get_activation(self.cfg.density_activation)(
+            out_list.view([resolution * resolution * resolution, 1]) + self.cfg.density_bias
+        )
+    def do_query_triplane_volume_density(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, layer_count: int) -> torch.Tensor:
+        step = 2.0 * layer_count / (resolution - 1)
+        indices2D = self.make_step_grid(triplane.device, resolution, layer_count)
+
+        out_list = torch.zeros([resolution, resolution * resolution, 1],
+                               device=triplane.device)
+        for i in range(0, resolution, layer_count):
+            if i + layer_count > resolution:
+                layer_count = resolution - i
+                indices2D = indices2D[..., :resolution * resolution * layer_count, :]
+            density_step = self.sample_step_triplane_volume_density(decoder, triplane, indices2D)
+            # todo directly march cube
+            out_list[i:i + layer_count] = density_step.view([layer_count, resolution * resolution, 1])
+            #out_list.append(net_out['density'])
+            indices2D.transpose(1, 2)[0, 0] += step
+            indices2D.transpose(1, 2)[1, 0] += step
+
+        return out_list
+    def sample_step_triplane_volume_density(self, decoder, triplane, indices2D):
+        out: torch.Tensor = F.grid_sample(
+            rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
+            rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
+            align_corners=False,
+            mode="bilinear",
+        )
+        if self.cfg.feature_reduction == "concat":
+            out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
+        elif self.cfg.feature_reduction == "mean":
+            out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
+        else:
+            raise NotImplementedError
 
+        net_out: Dict[str, torch.Tensor] = decoder(out)
+        return net_out['density']
     def query_triplane(
         self,
         decoder: torch.nn.Module,
@@ -49,9 +101,9 @@ class TriplaneNeRFRenderer(BaseModule):
 
         # positions in (-radius, radius)
         # normalized to (-1, 1) for grid sample
-        positions = scale_tensor(
-            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
-        )
+        #positions = scale_tensor(
+        #    positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
+        #)
 
         def _query_chunk(x):
             indices2D: torch.Tensor = torch.stack(