guangkaixu commited on
Commit
70926c9
·
1 Parent(s): bc0cf88
app.py CHANGED
@@ -80,7 +80,7 @@ def process_depth(
80
  batch_size=1 if processing_res == 0 else 0,
81
  show_progress_bar=False,
82
  mode='depth',
83
- cmap='Spectral',
84
  )
85
 
86
  depth_pred = pipe_out.pred_np
@@ -252,7 +252,7 @@ def process_disparity(
252
  batch_size=1 if processing_res == 0 else 0,
253
  show_progress_bar=False,
254
  mode='disparity',
255
- cmap='Spectral',
256
  )
257
 
258
  disparity_pred = pipe_out.pred_np
 
80
  batch_size=1 if processing_res == 0 else 0,
81
  show_progress_bar=False,
82
  mode='depth',
83
+ color_map='Spectral',
84
  )
85
 
86
  depth_pred = pipe_out.pred_np
 
252
  batch_size=1 if processing_res == 0 else 0,
253
  show_progress_bar=False,
254
  mode='disparity',
255
+ color_map='Spectral',
256
  )
257
 
258
  disparity_pred = pipe_out.pred_np
genpercept/genpercept_pipeline.py CHANGED
@@ -41,6 +41,7 @@ from .util.image_util import (
41
 
42
  import matplotlib.pyplot as plt
43
  from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsampleIdentity
 
44
 
45
 
46
  class GenPerceptOutput(BaseOutput):
@@ -309,6 +310,9 @@ class GenPerceptPipeline(DiffusionPipeline):
309
  # Clip output range
310
  pipe_pred = pipe_pred.clip(0, 1)
311
 
 
 
 
312
  # Colorize
313
  if color_map is not None:
314
  assert self.mode == 'depth'
 
41
 
42
  import matplotlib.pyplot as plt
43
  from genpercept.models.dpt_head import DPTNeckHeadForUnetAfterUpsampleIdentity
44
+ from genpercept.util.image_util import process_normals
45
 
46
 
47
  class GenPerceptOutput(BaseOutput):
 
310
  # Clip output range
311
  pipe_pred = pipe_pred.clip(0, 1)
312
 
313
+ if mode == 'normal':
314
+ pred_np = process_normals(torch.from_numpy(pred_np)[None])
315
+
316
  # Colorize
317
  if color_map is not None:
318
  assert self.mode == 'depth'
genpercept/util/image_util.py CHANGED
@@ -25,6 +25,19 @@ import torch
25
  from torchvision.transforms import InterpolationMode
26
  from torchvision.transforms.functional import resize
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def colorize_depth_maps(
30
  depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
 
25
  from torchvision.transforms import InterpolationMode
26
  from torchvision.transforms.functional import resize
27
 
28
+ def process_normals(input_images:torch.Tensor):
29
+ normal_preds = input_images
30
+ bsz, d, h, w = normal_preds.shape
31
+ normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5)
32
+ phi = torch.atan2(normal_preds[:,1,:,:], normal_preds[:,0,:,:]).mean(dim=0)
33
+ theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0)
34
+ normal_pred = torch.zeros((d,h,w)).to(normal_preds)
35
+ normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi)
36
+ normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi)
37
+ normal_pred[2,:,:] = torch.cos(theta)
38
+ angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1),-0.999, 0.999))
39
+ normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1))
40
+ return normal_preds[normal_idx]
41
 
42
  def colorize_depth_maps(
43
  depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None