michaelj commited on
Commit
d08dc68
·
verified ·
1 Parent(s): c851a4b

#positions = scale_tensor( # positions, (-self.cfg.radius, self.cfg.radius), (-1, 1) #)

Browse files
Files changed (1) hide show
  1. tsr/models/nerf_renderer.py +56 -4
tsr/models/nerf_renderer.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import dataclass, field
2
- from typing import Dict
3
 
4
  import torch
5
  import torch.nn.functional as F
@@ -37,7 +37,59 @@ class TriplaneNeRFRenderer(BaseModule):
37
  chunk_size >= 0
38
  ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
  self.chunk_size = chunk_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
41
  def query_triplane(
42
  self,
43
  decoder: torch.nn.Module,
@@ -49,9 +101,9 @@ class TriplaneNeRFRenderer(BaseModule):
49
 
50
  # positions in (-radius, radius)
51
  # normalized to (-1, 1) for grid sample
52
- positions = scale_tensor(
53
- positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54
- )
55
 
56
  def _query_chunk(x):
57
  indices2D: torch.Tensor = torch.stack(
 
1
  from dataclasses import dataclass, field
2
+ from typing import Dict, Optional
3
 
4
  import torch
5
  import torch.nn.functional as F
 
37
  chunk_size >= 0
38
  ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
  self.chunk_size = chunk_size
40
+ def make_step_grid(self,device, resolution: int, chunk_size: int = 32):
41
+ coords = torch.linspace(-1.0, 1.0, resolution, device = device)
42
+ x, y, z = torch.meshgrid(coords[0:chunk_size], coords, coords, indexing="ij")
43
+ x = x.reshape(-1, 1)
44
+ y = y.reshape(-1, 1)
45
+ z = z.reshape(-1, 1)
46
+ verts = torch.cat([x, y, z], dim = -1).view(-1, 3)
47
+ indices2D: torch.Tensor = torch.stack(
48
+ (verts[..., [0, 1]], verts[..., [0, 2]], verts[..., [1, 2]]),
49
+ dim=-3,
50
+ )
51
+ return indices2D
52
+
53
+ def query_triplane_volume_density(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, sample_count: int = 1024 * 1024 * 4) -> torch.Tensor:
54
+ layer_count = sample_count // (resolution * resolution)
55
+ out_list = self.do_query_triplane_volume_density(decoder, triplane, resolution, layer_count)
56
+ return get_activation(self.cfg.density_activation)(
57
+ out_list.view([resolution * resolution * resolution, 1]) + self.cfg.density_bias
58
+ )
59
+ def do_query_triplane_volume_density(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, layer_count: int) -> torch.Tensor:
60
+ step = 2.0 * layer_count / (resolution - 1)
61
+ indices2D = self.make_step_grid(triplane.device, resolution, layer_count)
62
+
63
+ out_list = torch.zeros([resolution, resolution * resolution, 1], device = triplane.device
64
+ )
65
+ for i in range(0, resolution, layer_count):
66
+ if i + layer_count > resolution:
67
+ layer_count = resolution - i
68
+ indices2D = indices2D[..., :resolution * resolution * layer_count, :]
69
+ density_step = self.sample_step_triplane_volume_density(decoder, triplane, indices2D)
70
+ # todo directly march cube
71
+ out_list[i:i + layer_count] = density_step.view([layer_count, resolution * resolution, 1])
72
+ #out_list.append(net_out['density'])
73
+ indices2D.transpose(1, 2)[0, 0] += step
74
+ indices2D.transpose(1, 2)[1, 0] += step
75
+
76
+ return out_list
77
+ def sample_step_triplane_volume_density(self, decoder, triplane, indices2D):
78
+ out: torch.Tensor = F.grid_sample(
79
+ rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
80
+ rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
81
+ align_corners=False,
82
+ mode="bilinear",
83
+ )
84
+ if self.cfg.feature_reduction == "concat":
85
+ out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
86
+ elif self.cfg.feature_reduction == "mean":
87
+ out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
88
+ else:
89
+ raise NotImplementedError
90
 
91
+ net_out: Dict[str, torch.Tensor] = decoder(out)
92
+ return net_out['density']
93
  def query_triplane(
94
  self,
95
  decoder: torch.nn.Module,
 
101
 
102
  # positions in (-radius, radius)
103
  # normalized to (-1, 1) for grid sample
104
+ #positions = scale_tensor(
105
+ # positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
106
+ #)
107
 
108
  def _query_chunk(x):
109
  indices2D: torch.Tensor = torch.stack(