Yiwen-ntu commited on
Commit
acc6365
·
verified ·
1 Parent(s): a812784

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import subprocess
3
+ # Install flash attention, skipping CUDA build if necessary
4
+ subprocess.run(
5
+ "pip install flash-attn --no-build-isolation",
6
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
7
+ shell=True,
8
+ )
9
+ import os
10
+ import torch
11
+ import trimesh
12
+ from accelerate.utils import set_seed
13
+ from accelerate import Accelerator
14
+ import numpy as np
15
+ import gradio as gr
16
+ from main import get_args, load_model
17
+ from mesh_to_pc import process_mesh_to_pc
18
+ import time
19
+ import matplotlib.pyplot as plt
20
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
21
+ from PIL import Image
22
+ import io
23
+
24
+ args = get_args()
25
+ model = load_model(args)
26
+
27
+ device = torch.device('cuda')
28
+ accelerator = Accelerator(
29
+ mixed_precision="fp16",
30
+ )
31
+ model = accelerator.prepare(model)
32
+ model.eval()
33
+ print("Model loaded to device")
34
+
35
+ def wireframe_render(mesh):
36
+ views = [
37
+ (90, 20), (270, 20)
38
+ ]
39
+ mesh.vertices = mesh.vertices[:, [0, 2, 1]]
40
+
41
+ bounding_box = mesh.bounds
42
+ center = mesh.centroid
43
+ scale = np.ptp(bounding_box, axis=0).max()
44
+
45
+ fig = plt.figure(figsize=(10, 10))
46
+
47
+ # Function to render and return each view as an image
48
+ def render_view(mesh, azimuth, elevation):
49
+ ax = fig.add_subplot(111, projection='3d')
50
+ ax.set_axis_off()
51
+
52
+ # Extract vertices and faces for plotting
53
+ vertices = mesh.vertices
54
+ faces = mesh.faces
55
+
56
+ # Plot faces
57
+ ax.add_collection3d(Poly3DCollection(
58
+ vertices[faces],
59
+ facecolors=(0.8, 0.5, 0.2, 1.0), # Brownish yellow
60
+ edgecolors='k',
61
+ linewidths=0.5,
62
+ ))
63
+
64
+ # Set limits and center the view on the object
65
+ ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
66
+ ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2)
67
+ ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2)
68
+
69
+ # Set view angle
70
+ ax.view_init(elev=elevation, azim=azimuth)
71
+
72
+ # Save the figure to a buffer
73
+ buf = io.BytesIO()
74
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300)
75
+ plt.clf()
76
+ buf.seek(0)
77
+
78
+ return Image.open(buf)
79
+
80
+ # Render each view and store in a list
81
+ images = [render_view(mesh, az, el) for az, el in views]
82
+
83
+ # Combine images horizontally
84
+ widths, heights = zip(*(i.size for i in images))
85
+ total_width = sum(widths)
86
+ max_height = max(heights)
87
+
88
+ combined_image = Image.new('RGBA', (total_width, max_height))
89
+
90
+ x_offset = 0
91
+ for img in images:
92
+ combined_image.paste(img, (x_offset, 0))
93
+ x_offset += img.width
94
+
95
+ # Save the combined image
96
+ save_path = f"combined_mesh_view_{int(time.time())}.png"
97
+ combined_image.save(save_path)
98
+
99
+ plt.close(fig)
100
+ return save_path
101
+
102
+ @spaces.GPU(duration=120)
103
+ def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False):
104
+ set_seed(sample_seed)
105
+ print("Seed value:", sample_seed)
106
+
107
+ input_mesh = trimesh.load(input_3d)
108
+ pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes)
109
+ mesh = mesh_list[0]
110
+ mesh.merge_vertices()
111
+ mesh.update_faces(mesh.unique_faces())
112
+ mesh.fix_normals()
113
+ if mesh.visual.vertex_colors is not None:
114
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
115
+
116
+ mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
117
+ else:
118
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
119
+ mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1))
120
+ input_save_name = f"processed_input_{int(time.time())}.obj"
121
+ mesh.export(input_save_name)
122
+ input_render_res = wireframe_render(mesh)
123
+
124
+ pc_normal = pc_list[0] # 4096, 6
125
+ pc_coor = pc_normal[:, :3]
126
+ normals = pc_normal[:, 3:]
127
+
128
+ bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
129
+ pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
130
+ pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995
131
+ assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong"
132
+ normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
133
+
134
+ input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
135
+ print("Data loaded")
136
+
137
+ # with accelerator.autocast():
138
+ with accelerator.autocast():
139
+ outputs = model(input, do_sampling)
140
+ print("Model inference done")
141
+ recon_mesh = outputs[0]
142
+
143
+ recon_mesh = recon_mesh[~torch.isnan(recon_mesh[:, 0, 0])] # nvalid_face x 3 x 3
144
+ vertices = recon_mesh.reshape(-1, 3).cpu()
145
+ vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face
146
+ triangles = vertices_index.reshape(-1, 3)
147
+
148
+ artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh",
149
+ merge_primitives=True)
150
+ artist_mesh.merge_vertices()
151
+ artist_mesh.update_faces(artist_mesh.unique_faces())
152
+ artist_mesh.fix_normals()
153
+
154
+ if artist_mesh.visual.vertex_colors is not None:
155
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
156
+
157
+ artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
158
+ else:
159
+ orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)
160
+ artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1))
161
+
162
+ num_faces = len(artist_mesh.faces)
163
+
164
+ brown_color = np.array([165, 42, 42, 255], dtype=np.uint8)
165
+ face_colors = np.tile(brown_color, (num_faces, 1))
166
+
167
+ artist_mesh.visual.face_colors = face_colors
168
+ # add time stamp to avoid cache
169
+ save_name = f"output_{int(time.time())}.obj"
170
+ artist_mesh.export(save_name)
171
+ output_render = wireframe_render(artist_mesh)
172
+ return input_save_name, input_render_res, save_name, output_render
173
+
174
+
175
+ _HEADER_ = '''
176
+ <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/buaacyw/MeshAnything' target='_blank'><b>MeshAnything: Artist-Created Mesh Generation with Autoregressive Transformers</b></a></h2>
177
+
178
+ **MeshAnything** converts any 3D representation into meshes created by human artists, i.e., Artist-Created Meshes (AMs).
179
+
180
+ Code: <a href='https://github.com/buaacyw/MeshAnything' target='_blank'>GitHub</a>. Arxiv Paper: <a href='https://gaussianeditor.github.io/' target='_blank'>ArXiv</a>.
181
+
182
+ ❗️❗️❗️**Important Notes:**
183
+ - Gradio doesn't support interactive wireframe rendering currently. For interactive mesh visualization, please use download the obj file and open it with MeshLab or https://3dviewer.net/.
184
+ - The input mesh will be normalized to a unit bounding box. The up vector of the input mesh should be +Y for better results. Click **Preprocess with Marching Cubes** if the input mesh is a manually created mesh.
185
+ - Limited by computational resources, MeshAnything is trained on meshes with fewer than 800 faces and cannot generate meshes with more than 800 faces. The shape of the input mesh should be sharp enough; otherwise, it will be challenging to represent it with only 800 faces. Thus, feed-forward image-to-3D methods may often produce bad results due to insufficient shape quality.
186
+ - For point cloud input, please refer to our github repo <a href='https://github.com/buaacyw/MeshAnything' target='_blank'>GitHub</a>.
187
+ '''
188
+
189
+
190
+ _CITE_ = r"""
191
+ If MeshAnything is helpful, please help to ⭐ the <a href='https://github.com/buaacyw/MeshAnything' target='_blank'>Github Repo</a>. Thanks!
192
+ ---
193
+ 📋 **License**
194
+
195
+ S-Lab-1.0 LICENSE. Please refer to the [LICENSE file](https://github.com/buaacyw/GaussianEditor/blob/master/LICENSE.txt) for details.
196
+
197
+ 📧 **Contact**
198
+
199
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
200
+
201
+ """
202
+ output_model_obj = gr.Model3D(
203
+ label="Processed Input Mesh (OBJ Format)",
204
+ clear_color=[1, 1, 1, 1],
205
+ )
206
+ preprocess_model_obj = gr.Model3D(
207
+ label="Generated Mesh (OBJ Format)",
208
+ clear_color=[1, 1, 1, 1],
209
+ )
210
+ input_image_render = gr.Image(
211
+ label="Wireframe Render of Processed Input Mesh",
212
+ )
213
+ output_image_render = gr.Image(
214
+ label="Wireframe Render of Generated Mesh",
215
+ )
216
+ with (gr.Blocks() as demo):
217
+ gr.Markdown(_HEADER_)
218
+ with gr.Row(variant="panel"):
219
+ with gr.Column():
220
+ with gr.Row():
221
+ input_3d = gr.Model3D(
222
+ label="Input Mesh",
223
+ clear_color=[1,1,1,1],
224
+ )
225
+
226
+ with gr.Row():
227
+ with gr.Group():
228
+ do_marching_cubes = gr.Checkbox(label="Preprocess with Marching Cubes", value=False)
229
+ do_sampling = gr.Checkbox(label="Random Sampling", value=False)
230
+ sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
231
+
232
+ with gr.Row():
233
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
234
+
235
+ with gr.Row(variant="panel"):
236
+ mesh_examples = gr.Examples(
237
+ examples=[
238
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
239
+ ],
240
+ inputs=input_3d,
241
+ outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
242
+ fn=do_inference,
243
+ cache_examples = "lazy",
244
+ examples_per_page=16
245
+ )
246
+ with gr.Column():
247
+ with gr.Row():
248
+ input_image_render.render()
249
+ with gr.Row():
250
+ with gr.Tab("OBJ"):
251
+ preprocess_model_obj.render()
252
+ with gr.Row():
253
+ output_image_render.render()
254
+ with gr.Row():
255
+ with gr.Tab("OBJ"):
256
+ output_model_obj.render()
257
+ with gr.Row():
258
+ gr.Markdown('''Try click random sampling and different <b>Seed Value</b> if the result is unsatisfying''')
259
+
260
+ gr.Markdown(_CITE_)
261
+
262
+ mv_images = gr.State()
263
+
264
+ submit.click(
265
+ fn=do_inference,
266
+ inputs=[input_3d, sample_seed, do_sampling, do_marching_cubes],
267
+ outputs=[preprocess_model_obj, input_image_render, output_model_obj, output_image_render],
268
+ )
269
+
270
+ demo.launch(share=True)