init project
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import argparse
|
|
10 |
|
11 |
from modules.pe3r.demo import main_demo
|
12 |
from modules.pe3r.models import Models
|
|
|
13 |
|
14 |
# def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
|
15 |
# builtin_print = builtins.print
|
@@ -20,6 +21,8 @@ from modules.pe3r.models import Models
|
|
20 |
# builtin_print(*args, **kwargs)
|
21 |
# builtins.print = print_with_timestamp
|
22 |
|
|
|
|
|
23 |
def get_args_parser():
|
24 |
parser = argparse.ArgumentParser()
|
25 |
parser_url = parser.add_mutually_exclusive_group()
|
@@ -28,7 +31,7 @@ def get_args_parser():
|
|
28 |
parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
|
29 |
parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
|
30 |
"If None, will search for an available port starting at 7860."), default=None)
|
31 |
-
parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
|
32 |
parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
|
33 |
parser.add_argument("--silent", action='store_true', default=False, help="silence logs")
|
34 |
# change defaults
|
@@ -50,9 +53,9 @@ if __name__ == '__main__':
|
|
50 |
else:
|
51 |
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
52 |
|
53 |
-
pe3r = Models(device=
|
54 |
|
55 |
with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
|
56 |
if not args.silent:
|
57 |
print('Outputing stuff in', tmpdirname)
|
58 |
-
main_demo(tmpdirname, pe3r,
|
|
|
10 |
|
11 |
from modules.pe3r.demo import main_demo
|
12 |
from modules.pe3r.models import Models
|
13 |
+
import torch
|
14 |
|
15 |
# def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
|
16 |
# builtin_print = builtins.print
|
|
|
21 |
# builtin_print(*args, **kwargs)
|
22 |
# builtins.print = print_with_timestamp
|
23 |
|
24 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
25 |
+
|
26 |
def get_args_parser():
|
27 |
parser = argparse.ArgumentParser()
|
28 |
parser_url = parser.add_mutually_exclusive_group()
|
|
|
31 |
parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
|
32 |
parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
|
33 |
"If None, will search for an available port starting at 7860."), default=None)
|
34 |
+
# parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
|
35 |
parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
|
36 |
parser.add_argument("--silent", action='store_true', default=False, help="silence logs")
|
37 |
# change defaults
|
|
|
53 |
else:
|
54 |
server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
|
55 |
|
56 |
+
pe3r = Models(device=device)
|
57 |
|
58 |
with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
|
59 |
if not args.silent:
|
60 |
print('Outputing stuff in', tmpdirname)
|
61 |
+
main_demo(tmpdirname, pe3r, device, server_name, args.server_port, silent=args.silent)
|