ashawkey commited on
Commit
8ba0984
·
1 Parent(s): 6def56d

default to not decimate

Browse files
Files changed (2) hide show
  1. app.py +15 -9
  2. vae/utils.py +3 -3
app.py CHANGED
@@ -62,7 +62,7 @@ model.load_state_dict(ckpt_dict, strict=True)
62
 
63
  # process function
64
  @spaces.GPU(duration=120)
65
- def process(input_image, input_num_steps=30, input_cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True):
66
 
67
  # seed
68
  if randomize_seed:
@@ -89,7 +89,7 @@ def process(input_image, input_num_steps=30, input_cfg_scale=7.5, grid_res=384,
89
  data = {"cond_images": image_tensor}
90
 
91
  with torch.inference_mode():
92
- results = model(data, num_steps=input_num_steps, cfg_scale=input_cfg_scale)
93
 
94
  latent = results["latent"]
95
 
@@ -102,16 +102,19 @@ def process(input_image, input_num_steps=30, input_cfg_scale=7.5, grid_res=384,
102
  results_part0 = model.vae(data_part0, resolution=grid_res)
103
  results_part1 = model.vae(data_part1, resolution=grid_res)
104
 
 
 
 
105
  vertices, faces = results_part0["meshes"][0]
106
  mesh_part0 = trimesh.Trimesh(vertices, faces)
107
  mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
108
- mesh_part0 = postprocess_mesh(mesh_part0, 5e4)
109
  parts = mesh_part0.split(only_watertight=False)
110
 
111
  vertices, faces = results_part1["meshes"][0]
112
  mesh_part1 = trimesh.Trimesh(vertices, faces)
113
  mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
114
- mesh_part1 = postprocess_mesh(mesh_part1, 5e4)
115
  parts.extend(mesh_part1.split(only_watertight=False))
116
 
117
  # split connected components and assign different colors
@@ -147,23 +150,26 @@ with block:
147
  gr.Markdown(_DESCRIPTION)
148
 
149
  with gr.Row():
150
- with gr.Column(scale=2):
151
  # input image
152
  input_image = gr.Image(label="Image", type='pil')
153
  # inference steps
154
- input_num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
155
  # cfg scale
156
- input_cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.5)
157
  # grid resolution
158
  input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
159
  # random seed
160
  seed = gr.Slider(label="Random seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
161
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
162
  # gen button
163
  button_gen = gr.Button("Generate")
164
 
165
 
166
- with gr.Column(scale=4):
167
  with gr.Tab("3D Model"):
168
  # glb file
169
  output_model = gr.Model3D(label="Geometry", height=512)
@@ -190,6 +196,6 @@ with block:
190
  cache_examples=False,
191
  )
192
 
193
- button_gen.click(process, inputs=[input_image, input_num_steps, input_cfg_scale, input_grid_res, seed, randomize_seed], outputs=[seed, output_image, output_model])
194
 
195
  block.launch()
 
62
 
63
  # process function
64
  @spaces.GPU(duration=120)
65
+ def process(input_image, num_steps=30, cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True, simplify_mesh=False, target_num_faces=100000):
66
 
67
  # seed
68
  if randomize_seed:
 
89
  data = {"cond_images": image_tensor}
90
 
91
  with torch.inference_mode():
92
+ results = model(data, num_steps=num_steps, cfg_scale=cfg_scale)
93
 
94
  latent = results["latent"]
95
 
 
102
  results_part0 = model.vae(data_part0, resolution=grid_res)
103
  results_part1 = model.vae(data_part1, resolution=grid_res)
104
 
105
+ if not simplify_mesh:
106
+ target_num_faces = -1
107
+
108
  vertices, faces = results_part0["meshes"][0]
109
  mesh_part0 = trimesh.Trimesh(vertices, faces)
110
  mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
111
+ mesh_part0 = postprocess_mesh(mesh_part0, target_num_faces)
112
  parts = mesh_part0.split(only_watertight=False)
113
 
114
  vertices, faces = results_part1["meshes"][0]
115
  mesh_part1 = trimesh.Trimesh(vertices, faces)
116
  mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
117
+ mesh_part1 = postprocess_mesh(mesh_part1, target_num_faces)
118
  parts.extend(mesh_part1.split(only_watertight=False))
119
 
120
  # split connected components and assign different colors
 
150
  gr.Markdown(_DESCRIPTION)
151
 
152
  with gr.Row():
153
+ with gr.Column(scale=4):
154
  # input image
155
  input_image = gr.Image(label="Image", type='pil')
156
  # inference steps
157
+ num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
158
  # cfg scale
159
+ cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.0)
160
  # grid resolution
161
  input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
162
  # random seed
163
  seed = gr.Slider(label="Random seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
164
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
165
+ # simplify mesh
166
+ simplify_mesh = gr.Checkbox(label="Simplify mesh", value=False)
167
+ target_num_faces = gr.Slider(label="Face number", minimum=10000, maximum=1000000, step=1000, value=100000)
168
  # gen button
169
  button_gen = gr.Button("Generate")
170
 
171
 
172
+ with gr.Column(scale=8):
173
  with gr.Tab("3D Model"):
174
  # glb file
175
  output_model = gr.Model3D(label="Geometry", height=512)
 
196
  cache_examples=False,
197
  )
198
 
199
+ button_gen.click(process, inputs=[input_image, num_steps, cfg_scale, input_grid_res, seed, randomize_seed, simplify_mesh, target_num_faces], outputs=[seed, output_image, output_model])
200
 
201
  block.launch()
vae/utils.py CHANGED
@@ -287,10 +287,10 @@ def postprocess_mesh(mesh: trimesh.Trimesh, decimate_target=100000):
287
 
288
  if vertices.shape[0] > 0 and triangles.shape[0] > 0:
289
  vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
290
- if triangles.shape[0] > decimate_target:
291
  vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
292
- if vertices.shape[0] > 0 and triangles.shape[0] > 0:
293
- vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
294
 
295
  mesh.vertices = vertices
296
  mesh.faces = triangles
 
287
 
288
  if vertices.shape[0] > 0 and triangles.shape[0] > 0:
289
  vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
290
+ if decimate_target > 0 and triangles.shape[0] > decimate_target:
291
  vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
292
+ if vertices.shape[0] > 0 and triangles.shape[0] > 0:
293
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
294
 
295
  mesh.vertices = vertices
296
  mesh.faces = triangles