jadechoghari commited on
Commit
7841a02
·
verified ·
1 Parent(s): 6574f13

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +50 -48
modeling.py CHANGED
@@ -82,55 +82,57 @@ class LRMGenerator(PreTrainedModel):
82
  assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
83
  f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
84
 
85
- # transformer generating planes
86
- planes = self.transformer(image_feats, camera_embeddings)
87
- assert planes.shape[0] == N, "Batch size mismatch for planes"
88
- assert planes.shape[1] == 3, "Planes should have 3 channels"
89
-
90
- # Generate the mesh
91
- if export_mesh:
92
- import mcubes
93
- import trimesh
94
- grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
95
- vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
96
- vtx = vtx / (mesh_size - 1) * 2 - 1
97
- vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
98
- vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
99
- vtx_colors = (vtx_colors * 255).astype(np.uint8)
100
- mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
101
-
102
- mesh_path = "awesome_mesh.obj"
103
- mesh.export(mesh_path, 'obj')
104
 
105
- return planes, mesh_path
106
-
107
- # Generate video
108
- if export_video:
109
- render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
110
-
111
- frames = []
112
- chunk_size = 1 # Adjust chunk size as needed
113
- for i in range(0, render_cameras.shape[1], chunk_size):
114
- frame_chunk = self.synthesizer(
115
- planes,
116
- render_cameras[:, i:i + chunk_size],
117
- render_size,
118
- render_size,
119
- 0,
120
- 0
121
- )
122
- frames.append(frame_chunk['images_rgb'])
123
-
124
- frames = torch.cat(frames, dim=1)
125
- frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
126
-
127
- # Save video
128
- video_path = "awesome_video.mp4"
129
- imageio.mimwrite(video_path, frames, fps=fps)
130
-
131
- return planes, video_path
132
-
133
- return planes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
136
  # and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
 
82
  assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
83
  f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
84
 
85
+ with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # transformer generating planes
88
+ planes = self.transformer(image_feats, camera_embeddings)
89
+ assert planes.shape[0] == N, "Batch size mismatch for planes"
90
+ assert planes.shape[1] == 3, "Planes should have 3 channels"
91
+
92
+ # Generate the mesh
93
+ if export_mesh:
94
+ import mcubes
95
+ import trimesh
96
+ grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
97
+ vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
98
+ vtx = vtx / (mesh_size - 1) * 2 - 1
99
+ vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
100
+ vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
101
+ vtx_colors = (vtx_colors * 255).astype(np.uint8)
102
+ mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
103
+
104
+ mesh_path = "awesome_mesh.obj"
105
+ mesh.export(mesh_path, 'obj')
106
+
107
+ return planes, mesh_path
108
+
109
+ # Generate video
110
+ if export_video:
111
+ render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
112
+
113
+ frames = []
114
+ chunk_size = 1 # Adjust chunk size as needed
115
+ for i in range(0, render_cameras.shape[1], chunk_size):
116
+ frame_chunk = self.synthesizer(
117
+ planes,
118
+ render_cameras[:, i:i + chunk_size],
119
+ render_size,
120
+ render_size,
121
+ 0,
122
+ 0
123
+ )
124
+ frames.append(frame_chunk['images_rgb'])
125
+
126
+ frames = torch.cat(frames, dim=1)
127
+ frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
128
+
129
+ # Save video
130
+ video_path = "awesome_video.mp4"
131
+ imageio.mimwrite(video_path, frames, fps=fps)
132
+
133
+ return planes, video_path
134
+
135
+ return planes
136
 
137
  # Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
138
  # and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py