kkwinds commited on
Commit
875ee73
·
1 Parent(s): 1ffa2b4

deploy mast3r sfm

Browse files
Files changed (2) hide show
  1. app.py +8 -1
  2. mast3r/demo.py +1 -2
app.py CHANGED
@@ -19,11 +19,15 @@ import mast3r.utils.path_to_dust3r # noqa
19
  from dust3r.demo import set_print_with_timestamp
20
 
21
  import matplotlib.pyplot as pl
 
 
 
22
  pl.ion()
23
 
24
  torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
 
26
- if __name__ == '__main__':
 
27
  parser = get_args_parser()
28
  args = parser.parse_args(['--model_name', 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'])
29
  set_print_with_timestamp()
@@ -49,3 +53,6 @@ if __name__ == '__main__':
49
  os.makedirs(cache_path, exist_ok=True)
50
  main_demo(cache_path, model, args.retrieval_model, args.device, args.image_size, server_name, args.server_port,
51
  silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
 
 
 
 
19
  from dust3r.demo import set_print_with_timestamp
20
 
21
  import matplotlib.pyplot as pl
22
+
23
+ import spaces
24
+
25
  pl.ion()
26
 
27
  torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
28
 
29
+ @spaces.GPU(duration=300)
30
+ def main():
31
  parser = get_args_parser()
32
  args = parser.parse_args(['--model_name', 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'])
33
  set_print_with_timestamp()
 
53
  os.makedirs(cache_path, exist_ok=True)
54
  main_demo(cache_path, model, args.retrieval_model, args.device, args.image_size, server_name, args.server_port,
55
  silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
56
+
57
+ if __name__ == '__main__':
58
+ main()
mast3r/demo.py CHANGED
@@ -30,7 +30,6 @@ from dust3r.demo import get_args_parser as dust3r_get_args_parser
30
 
31
  import matplotlib.pyplot as pl
32
 
33
- import spaces
34
 
35
  class SparseGAState:
36
  def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
@@ -253,7 +252,7 @@ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
253
 
254
  return graph_opt, winsize, win_cyclic, refid
255
 
256
- @spaces.GPU(duration=300)
257
  def main_demo(tmpdirname, model, retrieval_model, device, image_size, server_name, server_port, silent=False,
258
  share=False, gradio_delete_cache=False):
259
  if not silent:
 
30
 
31
  import matplotlib.pyplot as pl
32
 
 
33
 
34
  class SparseGAState:
35
  def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
 
252
 
253
  return graph_opt, winsize, win_cyclic, refid
254
 
255
+
256
  def main_demo(tmpdirname, model, retrieval_model, device, image_size, server_name, server_port, silent=False,
257
  share=False, gradio_delete_cache=False):
258
  if not silent: