Wuvin commited on
Commit
94285bf
·
1 Parent(s): 3087b03

use pytorch3d to render, instead of nvdiffrast

Browse files
gradio_app/gradio_3dgen.py CHANGED
@@ -10,13 +10,8 @@ from scripts.refine_lr_to_sr import run_sr_fast
10
  from scripts.utils import save_glb_and_video
11
  from scripts.multiview_inference import geo_reconstruct
12
 
13
-
14
- import nvdiffrast.torch as dr
15
- dr.RasterizeGLContext(output_db=False)
16
  @spaces.GPU
17
  def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
18
- dr.RasterizeGLContext(output_db=False) # BUG: cuda_runtime_api.h: No such file or directory
19
-
20
  if preview_img is None:
21
  raise gr.Error("preview_img is none")
22
  if isinstance(preview_img, str):
 
10
  from scripts.utils import save_glb_and_video
11
  from scripts.multiview_inference import geo_reconstruct
12
 
 
 
 
13
  @spaces.GPU
14
  def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
 
 
15
  if preview_img is None:
16
  raise gr.Error("preview_img is none")
17
  if isinstance(preview_img, str):
mesh_reconstruction/recon.py CHANGED
@@ -6,14 +6,14 @@ from typing import List
6
  from mesh_reconstruction.remesh import calc_vertex_normals
7
  from mesh_reconstruction.opt import MeshOptimizer
8
  from mesh_reconstruction.func import make_star_cameras_orthographic
9
- from mesh_reconstruction.render import NormalsRenderer
10
  from scripts.utils import to_py3d_mesh, init_target
11
 
12
  def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
13
  vertices, faces = vertices.to("cuda"), faces.to("cuda")
14
  assert len(pils) == 4
15
  mv,proj = make_star_cameras_orthographic(4, 1)
16
- renderer = NormalsRenderer(mv,proj,list(pils[0].size))
17
 
18
  target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
19
  # 1. no rotate
 
6
  from mesh_reconstruction.remesh import calc_vertex_normals
7
  from mesh_reconstruction.opt import MeshOptimizer
8
  from mesh_reconstruction.func import make_star_cameras_orthographic
9
+ from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
10
  from scripts.utils import to_py3d_mesh, init_target
11
 
12
  def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
13
  vertices, faces = vertices.to("cuda"), faces.to("cuda")
14
  assert len(pils) == 4
15
  mv,proj = make_star_cameras_orthographic(4, 1)
16
+ renderer = Pytorch3DNormalsRenderer(mv,proj,list(pils[0].size))
17
 
18
  target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
19
  # 1. no rotate
mesh_reconstruction/refine.py CHANGED
@@ -5,7 +5,7 @@ from typing import List
5
  from mesh_reconstruction.remesh import calc_vertex_normals
6
  from mesh_reconstruction.opt import MeshOptimizer
7
  from mesh_reconstruction.func import make_star_cameras_orthographic
8
- from mesh_reconstruction.render import NormalsRenderer
9
  from scripts.project_mesh import multiview_color_projection, get_cameras_list
10
  from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
11
 
@@ -18,7 +18,7 @@ def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_e
18
 
19
  assert len(pils) == 4
20
  mv,proj = make_star_cameras_orthographic(4, 1)
21
- renderer = NormalsRenderer(mv,proj,list(pils[0].size))
22
 
23
  target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
24
  # 1. no rotate
 
5
  from mesh_reconstruction.remesh import calc_vertex_normals
6
  from mesh_reconstruction.opt import MeshOptimizer
7
  from mesh_reconstruction.func import make_star_cameras_orthographic
8
+ from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
9
  from scripts.project_mesh import multiview_color_projection, get_cameras_list
10
  from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
11
 
 
18
 
19
  assert len(pils) == 4
20
  mv,proj = make_star_cameras_orthographic(4, 1)
21
+ renderer = Pytorch3DNormalsRenderer(mv,proj,list(pils[0].size))
22
 
23
  target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
24
  # 1. no rotate
mesh_reconstruction/render.py CHANGED
@@ -49,3 +49,121 @@ class NormalsRenderer:
49
  col = torch.concat((col,alpha),dim=-1) #C,H,W,4
50
  col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
51
  return col #C,H,W,4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  col = torch.concat((col,alpha),dim=-1) #C,H,W,4
50
  col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
51
  return col #C,H,W,4
52
+
53
+ from pytorch3d.structures import Meshes
54
+ from pytorch3d.renderer.mesh.shader import ShaderBase
55
+ from pytorch3d.renderer import (
56
+ RasterizationSettings,
57
+ MeshRendererWithFragments,
58
+ TexturesVertex,
59
+ MeshRasterizer,
60
+ BlendParams,
61
+ FoVOrthographicCameras,
62
+ look_at_view_transform,
63
+ hard_rgb_blend,
64
+ )
65
+
66
+ class VertexColorShader(ShaderBase):
67
+ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
68
+ blend_params = kwargs.get("blend_params", self.blend_params)
69
+ texels = meshes.sample_textures(fragments)
70
+ return hard_rgb_blend(texels, fragments, blend_params)
71
+
72
+ def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
73
+ if len(mesh) != len(cameras):
74
+ if len(cameras) % len(mesh) == 0:
75
+ mesh = mesh.extend(len(cameras))
76
+ else:
77
+ raise NotImplementedError()
78
+
79
+ # render requires everything in float16 or float32
80
+ input_dtype = dtype
81
+ blend_params = BlendParams(1e-4, 1e-4, bkgd)
82
+
83
+ # Define the settings for rasterization and shading
84
+ raster_settings = RasterizationSettings(
85
+ image_size=(H, W),
86
+ blur_radius=blur_radius,
87
+ faces_per_pixel=faces_per_pixel,
88
+ clip_barycentric_coords=True,
89
+ bin_size=None,
90
+ max_faces_per_bin=500000,
91
+ )
92
+
93
+ # Create a renderer by composing a rasterizer and a shader
94
+ # We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used)
95
+ renderer = MeshRendererWithFragments(
96
+ rasterizer=MeshRasterizer(
97
+ cameras=cameras,
98
+ raster_settings=raster_settings
99
+ ),
100
+ shader=VertexColorShader(
101
+ device=device,
102
+ cameras=cameras,
103
+ blend_params=blend_params
104
+ )
105
+ )
106
+
107
+ # render RGB and depth, get mask
108
+ with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
109
+ images, _ = renderer(mesh)
110
+ return images # BHW4
111
+
112
+ class Pytorch3DNormalsRenderer:
113
+ def __init__(self, cameras, image_size, device):
114
+ self.cameras = cameras.to(device)
115
+ self._image_size = image_size
116
+ self.device = device
117
+
118
+ def render(self,
119
+ vertices: torch.Tensor, #V,3 float
120
+ normals: torch.Tensor, #V,3 float in [-1, 1]
121
+ faces: torch.Tensor, #F,3 long
122
+ ) ->torch.Tensor: #C,H,W,4
123
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
124
+ return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)
125
+
126
+ def get_camera(R, T, focal_length=1 / (2**0.5)):
127
+ focal_length = 1 / focal_length
128
+ camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
129
+ return camera
130
+
131
+ def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
132
+ R, T = look_at_view_transform(dist, 0, azim_list)
133
+ focal_length = 1 / focal
134
+ return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
135
+
136
+ def save_tensor_to_img(tensor, save_dir):
137
+ from PIL import Image
138
+ import numpy as np
139
+ for idx, img in enumerate(tensor):
140
+ img = img[..., :3].cpu().numpy()
141
+ img = (img * 255).astype(np.uint8)
142
+ img = Image.fromarray(img)
143
+ img.save(save_dir + f"{idx}.png")
144
+
145
+ if __name__ == "__main__":
146
+ import sys
147
+ import os
148
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
149
+ from mesh_reconstruction.func import make_star_cameras_orthographic
150
+ cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
151
+ mv,proj = make_star_cameras_orthographic(4, 1)
152
+ resolution = 1024
153
+ renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
154
+ renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
155
+ vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
156
+ normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
157
+ faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)
158
+
159
+ import time
160
+ t0 = time.time()
161
+ r1 = renderer1.render(vertices, normals, faces)
162
+ print("time r1:", time.time() - t0)
163
+
164
+ t0 = time.time()
165
+ r2 = renderer2.render(vertices, normals, faces)
166
+ print("time r2:", time.time() - t0)
167
+
168
+ for i in range(4):
169
+ print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
scripts/project_mesh.py CHANGED
@@ -13,17 +13,6 @@ from pytorch3d.renderer import (
13
  )
14
  from pytorch3d.renderer import MeshRasterizer
15
 
16
- def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
17
- # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
18
- R = world_to_cam[:3, :3].t()[None, ...]
19
- T = world_to_cam[:3, 3][None, ...]
20
- if cam_type == 'fov':
21
- camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
22
- else:
23
- focal_length = 1 / focal_length
24
- camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
25
- return camera
26
-
27
  def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
28
  """
29
  Renders pix2face of visible faces.
@@ -98,11 +87,11 @@ class Pix2FacesRenderer:
98
  pix2faces_renderer = None
99
 
100
  def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
101
- global pix2faces_renderer
102
- if pix2faces_renderer is None:
103
- pix2faces_renderer = Pix2FacesRenderer()
104
- # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
105
- pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
106
 
107
  unique_faces = torch.unique(pix_to_face.flatten())
108
  unique_faces = unique_faces[unique_faces != -1]
@@ -313,12 +302,19 @@ def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], ca
313
  del meshes
314
  return ret_mesh
315
 
 
 
 
 
 
 
 
 
316
  def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
317
  ret = []
318
  for azim in azim_list:
319
  R, T = look_at_view_transform(dist, 0, azim)
320
- w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
321
- cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device)
322
  ret.append(cameras)
323
  return ret
324
 
 
13
  )
14
  from pytorch3d.renderer import MeshRasterizer
15
 
 
 
 
 
 
 
 
 
 
 
 
16
  def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
17
  """
18
  Renders pix2face of visible faces.
 
87
  pix2faces_renderer = None
88
 
89
  def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
90
+ # global pix2faces_renderer
91
+ # if pix2faces_renderer is None:
92
+ # pix2faces_renderer = Pix2FacesRenderer()
93
+ pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
94
+ # pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
95
 
96
  unique_faces = torch.unique(pix_to_face.flatten())
97
  unique_faces = unique_faces[unique_faces != -1]
 
302
  del meshes
303
  return ret_mesh
304
 
305
+ def get_camera(R, T, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
306
+ if cam_type == 'fov':
307
+ camera = FoVPerspectiveCameras(device=R.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
308
+ else:
309
+ focal_length = 1 / focal_length
310
+ camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
311
+ return camera
312
+
313
  def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
314
  ret = []
315
  for azim in azim_list:
316
  R, T = look_at_view_transform(dist, 0, azim)
317
+ cameras: OrthographicCameras = get_camera(R, T, focal_length=focal, cam_type='orthogonal').to(device)
 
318
  ret.append(cameras)
319
  return ret
320