hujiecpp commited on
Commit
735791c
·
1 Parent(s): ae113d0

init project

Browse files
Files changed (3) hide show
  1. app.py +1 -3
  2. modules/pe3r/demo.py +4 -1
  3. modules/pe3r/models.py +5 -1
app.py CHANGED
@@ -21,8 +21,6 @@ import torch
21
  # builtin_print(*args, **kwargs)
22
  # builtins.print = print_with_timestamp
23
 
24
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
-
26
  def get_args_parser():
27
  parser = argparse.ArgumentParser()
28
  parser_url = parser.add_mutually_exclusive_group()
@@ -53,7 +51,7 @@ if __name__ == '__main__':
53
  else:
54
  server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
55
 
56
- pe3r = Models(device=device)
57
 
58
  with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
59
  if not args.silent:
 
21
  # builtin_print(*args, **kwargs)
22
  # builtins.print = print_with_timestamp
23
 
 
 
24
  def get_args_parser():
25
  parser = argparse.ArgumentParser()
26
  parser_url = parser.add_mutually_exclusive_group()
 
51
  else:
52
  server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
53
 
54
+ pe3r = Models()
55
 
56
  with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
57
  if not args.silent:
modules/pe3r/demo.py CHANGED
@@ -10,6 +10,7 @@ import math
10
  import gradio
11
  import os
12
  import torch
 
13
  import numpy as np
14
  import functools
15
  import trimesh
@@ -548,7 +549,9 @@ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
548
  return winsize, refid
549
 
550
 
551
- def main_demo(tmpdirname, pe3r, device, server_name, server_port, silent=False):
 
 
552
  # scene, outfile, imgs = get_reconstructed_scene(
553
  # outdir=tmpdirname, pe3r=pe3r, device=device, silent=silent,
554
  # filelist=['/home/hujie/pe3r/datasets/mipnerf360_ov/bonsai/black_chair/images/DSCF5590.png',
 
10
  import gradio
11
  import os
12
  import torch
13
+ import spaces
14
  import numpy as np
15
  import functools
16
  import trimesh
 
549
  return winsize, refid
550
 
551
 
552
+ @spaces.GPU
553
+ def main_demo(tmpdirname, pe3r, server_name, server_port, silent=False):
554
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
555
  # scene, outfile, imgs = get_reconstructed_scene(
556
  # outdir=tmpdirname, pe3r=pe3r, device=device, silent=silent,
557
  # filelist=['/home/hujie/pe3r/datasets/mipnerf360_ov/bonsai/black_chair/images/DSCF5590.png',
modules/pe3r/models.py CHANGED
@@ -10,9 +10,13 @@ from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
10
  from modules.mobilesamv2 import sam_model_registry
11
 
12
  from sam2.sam2_video_predictor import SAM2VideoPredictor
 
 
13
 
14
  class Models:
15
- def __init__(self, device):
 
 
16
  # -- mast3r --
17
  # MAST3R_CKP = './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
18
  MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
 
10
  from modules.mobilesamv2 import sam_model_registry
11
 
12
  from sam2.sam2_video_predictor import SAM2VideoPredictor
13
+ import spaces
14
+ import torch
15
 
16
  class Models:
17
+ @spaces.GPU
18
+ def __init__(self, device='cpu'):
19
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  # -- mast3r --
21
  # MAST3R_CKP = './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
22
  MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'