Update modeling.py
Browse files
- modeling.py +50 -48
modeling.py
CHANGED
@@ -82,55 +82,57 @@ class LRMGenerator(PreTrainedModel):
|
|
82 |
assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
|
83 |
f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
|
84 |
|
85 |
-
|
86 |
-
planes = self.transformer(image_feats, camera_embeddings)
|
87 |
-
assert planes.shape[0] == N, "Batch size mismatch for planes"
|
88 |
-
assert planes.shape[1] == 3, "Planes should have 3 channels"
|
89 |
-
|
90 |
-
# Generate the mesh
|
91 |
-
if export_mesh:
|
92 |
-
import mcubes
|
93 |
-
import trimesh
|
94 |
-
grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
|
95 |
-
vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
|
96 |
-
vtx = vtx / (mesh_size - 1) * 2 - 1
|
97 |
-
vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
|
98 |
-
vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
|
99 |
-
vtx_colors = (vtx_colors * 255).astype(np.uint8)
|
100 |
-
mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
|
101 |
-
|
102 |
-
mesh_path = "awesome_mesh.obj"
|
103 |
-
mesh.export(mesh_path, 'obj')
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
#
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
# Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
|
136 |
# and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
|
|
|
82 |
assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
|
83 |
f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
|
84 |
|
85 |
+
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
+
# transformer generating planes
|
88 |
+
planes = self.transformer(image_feats, camera_embeddings)
|
89 |
+
assert planes.shape[0] == N, "Batch size mismatch for planes"
|
90 |
+
assert planes.shape[1] == 3, "Planes should have 3 channels"
|
91 |
+
|
92 |
+
# Generate the mesh
|
93 |
+
if export_mesh:
|
94 |
+
import mcubes
|
95 |
+
import trimesh
|
96 |
+
grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
|
97 |
+
vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
|
98 |
+
vtx = vtx / (mesh_size - 1) * 2 - 1
|
99 |
+
vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
|
100 |
+
vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
|
101 |
+
vtx_colors = (vtx_colors * 255).astype(np.uint8)
|
102 |
+
mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
|
103 |
+
|
104 |
+
mesh_path = "awesome_mesh.obj"
|
105 |
+
mesh.export(mesh_path, 'obj')
|
106 |
+
|
107 |
+
return planes, mesh_path
|
108 |
+
|
109 |
+
# Generate video
|
110 |
+
if export_video:
|
111 |
+
render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
|
112 |
+
|
113 |
+
frames = []
|
114 |
+
chunk_size = 1 # Adjust chunk size as needed
|
115 |
+
for i in range(0, render_cameras.shape[1], chunk_size):
|
116 |
+
frame_chunk = self.synthesizer(
|
117 |
+
planes,
|
118 |
+
render_cameras[:, i:i + chunk_size],
|
119 |
+
render_size,
|
120 |
+
render_size,
|
121 |
+
0,
|
122 |
+
0
|
123 |
+
)
|
124 |
+
frames.append(frame_chunk['images_rgb'])
|
125 |
+
|
126 |
+
frames = torch.cat(frames, dim=1)
|
127 |
+
frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
|
128 |
+
|
129 |
+
# Save video
|
130 |
+
video_path = "awesome_video.mp4"
|
131 |
+
imageio.mimwrite(video_path, frames, fps=fps)
|
132 |
+
|
133 |
+
return planes, video_path
|
134 |
+
|
135 |
+
return planes
|
136 |
|
137 |
# Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
|
138 |
# and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
|