kkwinds commited on
Commit
d027e4d
·
1 Parent(s): 5bedd14

deploy mast3r sfm

Browse files
Files changed (4) hide show
  1. app.py +49 -5
  2. asmk +1 -0
  3. demo.py +0 -51
  4. requirements.txt +21 -1
app.py CHANGED
@@ -1,7 +1,51 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # gradio demo executable
7
+ # --------------------------------------------------------
8
+ import os
9
+ import torch
10
+ import tempfile
11
+ from contextlib import nullcontext
12
 
13
+ from mast3r.demo import get_args_parser, main_demo
 
14
 
15
+ from mast3r.model import AsymmetricMASt3R
16
+ from mast3r.utils.misc import hash_md5
17
+
18
+ import mast3r.utils.path_to_dust3r # noqa
19
+ from dust3r.demo import set_print_with_timestamp
20
+
21
+ import matplotlib.pyplot as pl
22
+ pl.ion()
23
+
24
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
+
26
+ if __name__ == '__main__':
27
+ parser = get_args_parser()
28
+ args = parser.parse_args()
29
+ set_print_with_timestamp()
30
+
31
+ if args.server_name is not None:
32
+ server_name = args.server_name
33
+ else:
34
+ server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
35
+
36
+ if args.weights is not None:
37
+ weights_path = args.weights
38
+ else:
39
+ weights_path = "naver/" + args.model_name
40
+
41
+ model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
42
+ chkpt_tag = hash_md5(weights_path)
43
+
44
+ def get_context(tmp_dir):
45
+ return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
46
+ else nullcontext(tmp_dir)
47
+ with get_context(args.tmp_dir) as tmpdirname:
48
+ cache_path = os.path.join(tmpdirname, chkpt_tag)
49
+ os.makedirs(cache_path, exist_ok=True)
50
+ main_demo(cache_path, model, args.retrieval_model, args.device, args.image_size, server_name, args.server_port,
51
+ silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
asmk ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 2a96d9c03a841dffdfddabc699a20512dcd09363
demo.py DELETED
@@ -1,51 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
- #
5
- # --------------------------------------------------------
6
- # gradio demo executable
7
- # --------------------------------------------------------
8
- import os
9
- import torch
10
- import tempfile
11
- from contextlib import nullcontext
12
-
13
- from mast3r.demo import get_args_parser, main_demo
14
-
15
- from mast3r.model import AsymmetricMASt3R
16
- from mast3r.utils.misc import hash_md5
17
-
18
- import mast3r.utils.path_to_dust3r # noqa
19
- from dust3r.demo import set_print_with_timestamp
20
-
21
- import matplotlib.pyplot as pl
22
- pl.ion()
23
-
24
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
-
26
- if __name__ == '__main__':
27
- parser = get_args_parser()
28
- args = parser.parse_args()
29
- set_print_with_timestamp()
30
-
31
- if args.server_name is not None:
32
- server_name = args.server_name
33
- else:
34
- server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
35
-
36
- if args.weights is not None:
37
- weights_path = args.weights
38
- else:
39
- weights_path = "naver/" + args.model_name
40
-
41
- model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
42
- chkpt_tag = hash_md5(weights_path)
43
-
44
- def get_context(tmp_dir):
45
- return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
46
- else nullcontext(tmp_dir)
47
- with get_context(args.tmp_dir) as tmpdirname:
48
- cache_path = os.path.join(tmpdirname, chkpt_tag)
49
- os.makedirs(cache_path, exist_ok=True)
50
- main_demo(cache_path, model, args.retrieval_model, args.device, args.image_size, server_name, args.server_port,
51
- silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1 +1,21 @@
1
- scikit-learn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scikit-learn
2
+ torch
3
+ torchvision
4
+ roma
5
+ gradio
6
+ matplotlib
7
+ tqdm
8
+ opencv-python
9
+ scipy
10
+ einops
11
+ trimesh
12
+ tensorboard
13
+ pyglet<2
14
+ huggingface-hub[torch]>=0.22
15
+ pillow-heif # add heif/heic image support
16
+ pyrender # for rendering depths in scannetpp
17
+ kapture # for visloc data loading
18
+ kapture-localization
19
+ numpy-quaternion
20
+ pycolmap # for pnp
21
+ poselib # for pnp