Wenzheng Chang committed on
Commit cd4da5b · 1 Parent(s): d5d6d85

final version

Files changed (2):
  1. app.py +33 -42
  2. scripts/demo_gradio.py +40 -62
app.py CHANGED

@@ -39,7 +39,6 @@ from aether.utils.postprocess_utils import ( # noqa: E402
 )
 from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402

-
 def seed_all(seed: int = 0) -> None:
     """
     Set random seeds of all components.
@@ -73,33 +72,10 @@ pipeline = AetherV1PipelineCogVideoX(
 )
 pipeline.vae.enable_slicing()
 pipeline.vae.enable_tiling()
-# pipeline.to(device)


 def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX:
     """Initialize the model pipeline."""
-    # cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
-    # aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
-    # pipeline = AetherV1PipelineCogVideoX(
-    #     tokenizer=AutoTokenizer.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path,
-    #         subfolder="tokenizer",
-    #     ),
-    #     text_encoder=T5EncoderModel.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
-    #     ),
-    #     vae=AutoencoderKLCogVideoX.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path, subfolder="vae"
-    #     ),
-    #     scheduler=CogVideoXDPMScheduler.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
-    #     ),
-    #     transformer=CogVideoXTransformer3DModel.from_pretrained(
-    #         aether_pretrained_model_name_or_path, subfolder="transformer"
-    #     ),
-    # )
-    # pipeline.vae.enable_slicing()
-    # pipeline.vae.enable_tiling()
     pipeline.to(device)
     return pipeline
@@ -346,21 +322,34 @@ def save_output_files(
     os.makedirs(output_dir, exist_ok=True)

     if pointmap is None and raymap is not None:
-        # Generate pointmap from raymap and disparity
-        smooth_camera = kwargs.get("smooth_camera", True)
-        smooth_method = (
-            kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
-        )
+        # # Generate pointmap from raymap and disparity
+        # smooth_camera = kwargs.get("smooth_camera", True)
+        # smooth_method = (
+        #     kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
+        # )
+
+        # pointmap_dict = postprocess_pointmap(
+        #     disparity,
+        #     raymap,
+        #     vae_downsample_scale=8,
+        #     ray_o_scale_inv=0.1,
+        #     smooth_camera=smooth_camera,
+        #     smooth_method=smooth_method,
+        # )
+        # pointmap = pointmap_dict["pointmap"]
+
+        window_result = AetherV1PipelineOutput(
+            rgb=rgb,
+            disparity=disparity,
+            raymap=raymap
+        )
+        window_results = [window_result]
+        window_indices = [0]
+        _, _, poses_from_blend, pointmap = blend_and_merge_window_results(window_results, window_indices, kwargs)

-        pointmap_dict = postprocess_pointmap(
-            disparity,
-            raymap,
-            vae_downsample_scale=8,
-            ray_o_scale_inv=0.1,
-            smooth_camera=smooth_camera,
-            smooth_method=smooth_method,
-        )
-        pointmap = pointmap_dict["pointmap"]
+        # Use poses from blend_and_merge_window_results if poses is None
+        if poses is None:
+            poses = poses_from_blend

     if poses is None and raymap is not None:
         poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
@@ -432,7 +421,7 @@ def save_output_files(
     # flip Y axis and X axis of camera position
     flipped_poses[..., 1, 3] = -flipped_poses[..., 1, 3]  # flip Y axis position
     flipped_poses[..., 0, 3] = -flipped_poses[..., 0, 3]  # flip X axis position
-
+
     # use flipped point cloud and camera poses
     predictions = {
         "world_points": flipped_pointmap,
@@ -1512,7 +1501,7 @@ with gr.Blocks(
         with gr.Column(scale=1):
             fps = gr.Dropdown(
                 choices=[8, 10, 12, 15, 24],
-                value=12,
+                value=24,
                 label="FPS",
                 info="Frames per second",
             )
@@ -1816,8 +1805,9 @@ with gr.Blocks(

     run_button.click(
         fn=lambda task_type,
-        video_file,
+        video_file,
         image_file,
+        image_input_planning,
         goal_file,
         height,
         width,
@@ -1874,7 +1864,7 @@ with gr.Blocks(
         ]
         if task_type == "prediction"
         else [
-            image_file,
+            image_input_planning,
             goal_file,
             height,
             width,
@@ -1897,6 +1887,7 @@ with gr.Blocks(
             task,
             video_input,
             image_input,
+            image_input_planning,
             goal_input,
             height,
             width,
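For orientation, the `save_output_files` change above replaces the direct `postprocess_pointmap` call with a single-window pass through `blend_and_merge_window_results`, which now yields both the pointmap and a fallback set of camera poses. A minimal sketch of that flow, with import paths and the meaning of the two discarded return values assumed rather than confirmed by the diff:

# Hypothetical helper mirroring the new pointmap path; import paths are assumptions.
from aether.pipelines.aetherv1_pipeline_cogvideox import AetherV1PipelineOutput
from aether.utils.postprocess_utils import blend_and_merge_window_results


def single_window_pointmap(rgb, disparity, raymap, kwargs):
    """Derive camera poses and a pointmap from one prediction window."""
    # Wrap the lone prediction as a one-window result starting at frame index 0.
    window_results = [AetherV1PipelineOutput(rgb=rgb, disparity=disparity, raymap=raymap)]
    window_indices = [0]
    # The diff unpacks four return values; the first two (presumably the merged
    # rgb and disparity) are discarded, the last two serve as poses and pointmap.
    _, _, poses, pointmap = blend_and_merge_window_results(
        window_results, window_indices, kwargs
    )
    return poses, pointmap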
scripts/demo_gradio.py CHANGED

@@ -17,8 +17,10 @@ from diffusers import (
     CogVideoXTransformer3DModel,
 )
 from transformers import AutoTokenizer, T5EncoderModel
-import spaces
+# import spaces

+os.environ['GRADIO_TEMP_DIR'] = '.gradio_cache'
+os.makedirs(os.environ['GRADIO_TEMP_DIR'], exist_ok=True)

 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

@@ -39,7 +41,6 @@ from aether.utils.postprocess_utils import ( # noqa: E402
 )
 from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402

-
 def seed_all(seed: int = 0) -> None:
     """
     Set random seeds of all components.
@@ -73,33 +74,10 @@ pipeline = AetherV1PipelineCogVideoX(
 )
 pipeline.vae.enable_slicing()
 pipeline.vae.enable_tiling()
-# pipeline.to(device)


 def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX:
     """Initialize the model pipeline."""
-    # cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
-    # aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
-    # pipeline = AetherV1PipelineCogVideoX(
-    #     tokenizer=AutoTokenizer.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path,
-    #         subfolder="tokenizer",
-    #     ),
-    #     text_encoder=T5EncoderModel.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
-    #     ),
-    #     vae=AutoencoderKLCogVideoX.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path, subfolder="vae"
-    #     ),
-    #     scheduler=CogVideoXDPMScheduler.from_pretrained(
-    #         cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
-    #     ),
-    #     transformer=CogVideoXTransformer3DModel.from_pretrained(
-    #         aether_pretrained_model_name_or_path, subfolder="transformer"
-    #     ),
-    # )
-    # pipeline.vae.enable_slicing()
-    # pipeline.vae.enable_tiling()
     pipeline.to(device)
     return pipeline
@@ -346,21 +324,34 @@ def save_output_files(
     os.makedirs(output_dir, exist_ok=True)

     if pointmap is None and raymap is not None:
-        # Generate pointmap from raymap and disparity
-        smooth_camera = kwargs.get("smooth_camera", True)
-        smooth_method = (
-            kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
-        )
+        # # Generate pointmap from raymap and disparity
+        # smooth_camera = kwargs.get("smooth_camera", True)
+        # smooth_method = (
+        #     kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
+        # )
+
+        # pointmap_dict = postprocess_pointmap(
+        #     disparity,
+        #     raymap,
+        #     vae_downsample_scale=8,
+        #     ray_o_scale_inv=0.1,
+        #     smooth_camera=smooth_camera,
+        #     smooth_method=smooth_method,
+        # )
+        # pointmap = pointmap_dict["pointmap"]
+
+        window_result = AetherV1PipelineOutput(
+            rgb=rgb,
+            disparity=disparity,
+            raymap=raymap
+        )
+        window_results = [window_result]
+        window_indices = [0]
+        _, _, poses_from_blend, pointmap = blend_and_merge_window_results(window_results, window_indices, kwargs)

-        pointmap_dict = postprocess_pointmap(
-            disparity,
-            raymap,
-            vae_downsample_scale=8,
-            ray_o_scale_inv=0.1,
-            smooth_camera=smooth_camera,
-            smooth_method=smooth_method,
-        )
-        pointmap = pointmap_dict["pointmap"]
+        # Use poses from blend_and_merge_window_results if poses is None
+        if poses is None:
+            poses = poses_from_blend

     if poses is None and raymap is not None:
         poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
@@ -432,7 +423,7 @@ def save_output_files(
     # flip Y axis and X axis of camera position
     flipped_poses[..., 1, 3] = -flipped_poses[..., 1, 3]  # flip Y axis position
     flipped_poses[..., 0, 3] = -flipped_poses[..., 0, 3]  # flip X axis position
-
+
     # use flipped point cloud and camera poses
     predictions = {
         "world_points": flipped_pointmap,
@@ -461,7 +452,7 @@ def save_output_files(
     return paths


-@spaces.GPU(duration=300)
+# @spaces.GPU(duration=300)
 def process_reconstruction(
     video_file,
     height,
@@ -586,7 +577,7 @@ def process_reconstruction(
     return None, None, []


-@spaces.GPU(duration=300)
+# @spaces.GPU(duration=300)
 def process_prediction(
     image_file,
     height,
@@ -718,7 +709,7 @@ def process_prediction(
     return None, None, []


-@spaces.GPU(duration=300)
+# @spaces.GPU(duration=300)
 def process_planning(
     image_file,
     goal_file,
@@ -1377,21 +1368,6 @@ with gr.Blocks(

     with gr.Row(elem_classes=["main-interface"]):
         with gr.Column(elem_classes=["input-column"]):
-            gpu_time_warning = gr.Markdown(
-                """
-                <div class="warning-box">
-                <strong>⚠️ Warning:</strong><br>
-                Due to HuggingFace Spaces ZERO GPU quota limitations, only short video reconstruction tasks (less than 100 frames) can be completed online.
-
-                <strong>💻 Recommendation:</strong><br>
-                We strongly encourage you to deploy Aether locally for:
-                - Processing longer video reconstruction tasks
-                - Better performance and full access to prediction and planning tasks
-
-                Visit our <a href="https://github.com/OpenRobotLab/Aether" target="_blank">GitHub repository</a> for local deployment instructions.
-                </div>
-                """,
-            )
             with gr.Group(elem_classes=["task-selector"]):
                 task = gr.Radio(
                     ["reconstruction", "prediction", "planning"],
@@ -1512,7 +1488,7 @@ with gr.Blocks(
         with gr.Column(scale=1):
             fps = gr.Dropdown(
                 choices=[8, 10, 12, 15, 24],
-                value=12,
+                value=24,
                 label="FPS",
                 info="Frames per second",
             )
@@ -1816,8 +1792,9 @@ with gr.Blocks(

     run_button.click(
         fn=lambda task_type,
-        video_file,
+        video_file,
         image_file,
+        image_input_planning,
         goal_file,
         height,
         width,
@@ -1874,7 +1851,7 @@ with gr.Blocks(
         ]
         if task_type == "prediction"
         else [
-            image_file,
+            image_input_planning,
             goal_file,
             height,
             width,
@@ -1897,6 +1874,7 @@ with gr.Blocks(
             task,
             video_input,
             image_input,
+            image_input_planning,
             goal_input,
             height,
             width,
@@ -1940,4 +1918,4 @@

 if __name__ == "__main__":
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    demo.queue(max_size=20).launch(show_error=True, share=True)
+    demo.queue(max_size=20).launch(show_error=True, share=False, server_port=7860)
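Taken together, the scripts/demo_gradio.py edits retarget the demo from HuggingFace Spaces ZERO GPU to a plain local run: the `spaces` import and `@spaces.GPU(duration=300)` decorators are commented out, the Spaces quota warning panel is removed, Gradio temp files move into a local cache directory, and the server binds to a fixed port with no public share link. A condensed, self-contained sketch of the resulting launch path, with a placeholder standing in for the full Aether UI:

import os

import gradio as gr

# Keep Gradio's temporary files inside the working tree, as the diff does.
os.environ["GRADIO_TEMP_DIR"] = ".gradio_cache"
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)

# Stand-in for the full Aether UI defined in the script.
with gr.Blocks() as demo:
    gr.Markdown("Aether demo placeholder")

if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # share=False disables the public Gradio tunnel; server_port pins the app
    # to 7860 (Gradio's default, made explicit here).
    demo.queue(max_size=20).launch(show_error=True, share=False, server_port=7860)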