JUGGHM commited on
Commit
6b09c9a
1 Parent(s): 81610d3

Update mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py

Browse files
mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py CHANGED
@@ -792,7 +792,8 @@ class RAFTDepthNormalDPT5(nn.Module):
792
  self.relu = nn.ReLU(inplace=True)
793
 
794
  def get_bins(self, bins_num):
795
- depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
 
796
  depth_bins_vec = torch.exp(depth_bins_vec)
797
  return depth_bins_vec
798
 
@@ -847,7 +848,7 @@ class RAFTDepthNormalDPT5(nn.Module):
847
  return norm_normalize(torch.cat([normal_out, confidence], dim=1))
848
  #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
849
 
850
- def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
851
  y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
852
  torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
853
  meshgrid = torch.stack((x, y))
 
792
  self.relu = nn.ReLU(inplace=True)
793
 
794
  def get_bins(self, bins_num):
795
+ #depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
796
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cpu")
797
  depth_bins_vec = torch.exp(depth_bins_vec)
798
  return depth_bins_vec
799
 
 
848
  return norm_normalize(torch.cat([normal_out, confidence], dim=1))
849
  #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
850
 
851
+ def create_mesh_grid(self, height, width, batch, device="cpu", set_buffer=True):
852
  y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
853
  torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
854
  meshgrid = torch.stack((x, y))