slothfulxtx commited on
Commit
3f62bf3
·
1 Parent(s): c96d544

update layout

Browse files
Files changed (1) hide show
  1. app.py +60 -50
app.py CHANGED
@@ -85,10 +85,6 @@ except Exception as e:
85
  # pipe.enable_xformers_memory_efficient_attention()
86
  pipe.enable_attention_slicing()
87
 
88
- mesh_seqs = []
89
- frame_seqs = []
90
- cur_mesh_idx = None
91
-
92
  def read_video_frames(video_path, process_length, max_res):
93
  print("==> processing video: ", video_path)
94
  vid = VideoReader(video_path, ctx=cpu(0))
@@ -131,12 +127,10 @@ def infer_geometry(
131
  decode_chunk_size: int,
132
  overlap: int,
133
  downsample_ratio: float = 1.0, # downsample pcd for visualization
134
- num_sample_frames: int =8, # downsample frames for visualization
135
  remove_edge: bool = True, # remove edge for visualization
136
  save_folder: str = os.path.join('workspace', 'GeometryCrafterApp'),
137
  ):
138
  try:
139
- global cur_mesh_idx, mesh_seqs, frame_seqs
140
  run_id = str(uuid.uuid4())
141
  set_seed(42)
142
  pipe.enable_xformers_memory_efficient_attention()
@@ -206,11 +200,10 @@ def infer_geometry(
206
  edge_mask = compute_edge_mask(point_maps[i, :, :, 2], 3)
207
  valid_masks[i] = valid_masks[i] & (~edge_mask)
208
 
209
- indices = np.linspace(0, len(point_maps)-1, num_sample_frames)
210
  indices = np.round(indices).astype(np.int32)
211
-
212
- mesh_seqs.clear()
213
- cur_mesh_idx = None
214
 
215
  for index in indices:
216
 
@@ -224,20 +217,17 @@ def infer_geometry(
224
  mesh_seqs.append(output_glb_path)
225
  frame_seqs.append(index)
226
 
227
- cur_mesh_idx = 0
228
 
229
  gc.collect()
230
  torch.cuda.empty_cache()
231
 
232
  return [
233
- gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}"),
 
234
  gr.Video(value=output_disp_path, label="Disparity", interactive=False),
235
  gr.DownloadButton("Download Npz File", value=output_npz_path, visible=True)
236
  ]
237
  except Exception as e:
238
- mesh_seqs.clear()
239
- frame_seqs.clear()
240
- cur_mesh_idx = None
241
  gc.collect()
242
  torch.cuda.empty_cache()
243
  raise gr.Error(str(e))
@@ -251,21 +241,6 @@ def infer_geometry(
251
  # gr.DownloadButton("Download Npz File", visible=False)
252
  # ]
253
 
254
- def goto_prev_frame():
255
- global cur_mesh_idx, mesh_seqs, frame_seqs
256
- if cur_mesh_idx is not None and len(mesh_seqs) > 0:
257
- if cur_mesh_idx > 0:
258
- cur_mesh_idx -= 1
259
- return gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}")
260
-
261
-
262
- def goto_next_frame():
263
- global cur_mesh_idx, mesh_seqs, frame_seqs
264
- if cur_mesh_idx is not None and len(mesh_seqs) > 0:
265
- if cur_mesh_idx < len(mesh_seqs)-1:
266
- cur_mesh_idx += 1
267
- return gr.Model3D(value=mesh_seqs[cur_mesh_idx], label=f"Frame: {frame_seqs[cur_mesh_idx]}")
268
-
269
  def download_file():
270
  return gr.DownloadButton(visible=False)
271
 
@@ -353,23 +328,58 @@ def build_demo():
353
  generate_btn = gr.Button("Generate")
354
 
355
  with gr.Column(scale=1):
356
- output_point_maps = gr.Model3D(
357
- label="Point Map",
 
 
 
 
 
 
 
 
358
  clear_color=[1.0, 1.0, 1.0, 1.0],
359
  # display_mode="solid"
360
  interactive=False
361
  )
362
- with gr.Row():
363
- prev_btn = gr.Button("Prev")
364
- next_btn = gr.Button("Next")
365
-
366
  with gr.Column(scale=1):
367
- output_disp_video = gr.Video(
368
- label="Disparity",
 
 
369
  interactive=False
370
  )
371
- download_btn = gr.DownloadButton("Download Npz File", visible=False)
372
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  gr.Examples(
374
  examples=examples,
375
  fn=infer_geometry,
@@ -383,7 +393,11 @@ def build_demo():
383
  decode_chunk_size,
384
  overlap,
385
  ],
386
- outputs=[output_point_maps, output_disp_video, download_btn],
 
 
 
 
387
  # cache_examples="lazy",
388
  )
389
  gr.Markdown(
@@ -412,17 +426,13 @@ def build_demo():
412
  decode_chunk_size,
413
  overlap,
414
  ],
415
- outputs=[output_point_maps, output_disp_video, download_btn],
 
 
 
 
416
  )
417
 
418
- prev_btn.click(
419
- fn=goto_prev_frame,
420
- outputs=output_point_maps,
421
- )
422
- next_btn.click(
423
- fn=goto_next_frame,
424
- outputs=output_point_maps,
425
- )
426
  download_btn.click(
427
  fn=download_file,
428
  outputs=download_btn
 
85
  # pipe.enable_xformers_memory_efficient_attention()
86
  pipe.enable_attention_slicing()
87
 
 
 
 
 
88
  def read_video_frames(video_path, process_length, max_res):
89
  print("==> processing video: ", video_path)
90
  vid = VideoReader(video_path, ctx=cpu(0))
 
127
  decode_chunk_size: int,
128
  overlap: int,
129
  downsample_ratio: float = 1.0, # downsample pcd for visualization
 
130
  remove_edge: bool = True, # remove edge for visualization
131
  save_folder: str = os.path.join('workspace', 'GeometryCrafterApp'),
132
  ):
133
  try:
 
134
  run_id = str(uuid.uuid4())
135
  set_seed(42)
136
  pipe.enable_xformers_memory_efficient_attention()
 
200
  edge_mask = compute_edge_mask(point_maps[i, :, :, 2], 3)
201
  valid_masks[i] = valid_masks[i] & (~edge_mask)
202
 
203
+ indices = np.linspace(0, len(point_maps)-1, 6)
204
  indices = np.round(indices).astype(np.int32)
205
+
206
+ mesh_seqs, frame_seqs = [], []
 
207
 
208
  for index in indices:
209
 
 
217
  mesh_seqs.append(output_glb_path)
218
  frame_seqs.append(index)
219
 
 
220
 
221
  gc.collect()
222
  torch.cuda.empty_cache()
223
 
224
  return [
225
+ gr.Model3D(value=mesh_seqs[idx], label=f"Frame: {frame_seqs[idx]}") for idx in range(len(frame_seqs))
226
+ ] + [
227
  gr.Video(value=output_disp_path, label="Disparity", interactive=False),
228
  gr.DownloadButton("Download Npz File", value=output_npz_path, visible=True)
229
  ]
230
  except Exception as e:
 
 
 
231
  gc.collect()
232
  torch.cuda.empty_cache()
233
  raise gr.Error(str(e))
 
241
  # gr.DownloadButton("Download Npz File", visible=False)
242
  # ]
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def download_file():
245
  return gr.DownloadButton(visible=False)
246
 
 
328
  generate_btn = gr.Button("Generate")
329
 
330
  with gr.Column(scale=1):
331
+ output_disp_video = gr.Video(
332
+ label="Disparity",
333
+ interactive=False
334
+ )
335
+ download_btn = gr.DownloadButton("Download Npz File", visible=False)
336
+
337
+ with gr.Row(equal_height=True):
338
+ with gr.Column(scale=1):
339
+ output_point_map0 = gr.Model3D(
340
+ label="Point Map 0",
341
  clear_color=[1.0, 1.0, 1.0, 1.0],
342
  # display_mode="solid"
343
  interactive=False
344
  )
 
 
 
 
345
  with gr.Column(scale=1):
346
+ output_point_map1 = gr.Model3D(
347
+ label="Point Map 1",
348
+ clear_color=[1.0, 1.0, 1.0, 1.0],
349
+ # display_mode="solid"
350
  interactive=False
351
  )
352
+ with gr.Column(scale=1):
353
+ output_point_map2 = gr.Model3D(
354
+ label="Point Map 2",
355
+ clear_color=[1.0, 1.0, 1.0, 1.0],
356
+ # display_mode="solid"
357
+ interactive=False
358
+ )
359
+
360
+ with gr.Row(equal_height=True):
361
+ with gr.Column(scale=1):
362
+ output_point_map3 = gr.Model3D(
363
+ label="Point Map 3",
364
+ clear_color=[1.0, 1.0, 1.0, 1.0],
365
+ # display_mode="solid"
366
+ interactive=False
367
+ )
368
+ with gr.Column(scale=1):
369
+ output_point_map4 = gr.Model3D(
370
+ label="Point Map 4",
371
+ clear_color=[1.0, 1.0, 1.0, 1.0],
372
+ # display_mode="solid"
373
+ interactive=False
374
+ )
375
+ with gr.Column(scale=1):
376
+ output_point_map5 = gr.Model3D(
377
+ label="Point Map 5",
378
+ clear_color=[1.0, 1.0, 1.0, 1.0],
379
+ # display_mode="solid"
380
+ interactive=False
381
+ )
382
+
383
  gr.Examples(
384
  examples=examples,
385
  fn=infer_geometry,
 
393
  decode_chunk_size,
394
  overlap,
395
  ],
396
+ outputs=[
397
+ output_point_map0, output_point_map1, output_point_map2,
398
+ output_point_map3, output_point_map4, output_point_map5,
399
+ output_disp_video, download_btn
400
+ ],
401
  # cache_examples="lazy",
402
  )
403
  gr.Markdown(
 
426
  decode_chunk_size,
427
  overlap,
428
  ],
429
+ outputs=[
430
+ output_point_map0, output_point_map1, output_point_map2,
431
+ output_point_map3, output_point_map4, output_point_map5,
432
+ output_disp_video, download_btn
433
+ ],
434
  )
435
 
 
 
 
 
 
 
 
 
436
  download_btn.click(
437
  fn=download_file,
438
  outputs=download_btn