File size: 2,372 Bytes
1b65314
 
 
 
 
 
 
 
 
 
 
 
16cc0ca
1b65314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16cc0ca
1b65314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735791c
1b65314
 
 
 
16cc0ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import tempfile

import sys
sys.path.append(os.path.abspath('./modules'))

# import builtins
# import datetime
import argparse

from modules.pe3r.demo import main_demo
from modules.pe3r.models import Models
import torch

# def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
#     builtin_print = builtins.print
#     def print_with_timestamp(*args, **kwargs):
#         now = datetime.datetime.now()
#         formatted_date_time = now.strftime(time_format)
#         builtin_print(f'[{formatted_date_time}] ', end='')  # print with time stamp
#         builtin_print(*args, **kwargs)
#     builtins.print = print_with_timestamp

def get_args_parser():
    """Build the command-line parser for the pe3r gradio demo.

    Options defined:
      --local_network / --server_name  mutually exclusive ways of choosing
                                       the address the app is served on.
      --server_port                    integer port, ``None`` when omitted.
      --tmp_dir                        value later assigned to ``tempfile.tempdir``.
      --silent                         flag that silences logs.

    Returns:
        argparse.ArgumentParser: the configured parser, with ``prog`` set to
        ``'pe3r demo'``.
    """
    demo_parser = argparse.ArgumentParser(prog='pe3r demo')

    # Only one of these two address options may be given on the command line.
    address_group = demo_parser.add_mutually_exclusive_group()
    address_group.add_argument(
        "--local_network",
        action='store_true',
        default=False,
        help="make app accessible on local network: address will be set to 0.0.0.0",
    )
    address_group.add_argument(
        "--server_name",
        type=str,
        default=None,
        help="server url, default is 127.0.0.1",
    )

    port_help = ("will start gradio app on this port (if available). "
                 "If None, will search for an available port starting at 7860.")
    demo_parser.add_argument("--server_port", type=int, default=None, help=port_help)

    # parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    demo_parser.add_argument(
        "--tmp_dir",
        type=str,
        default=None,
        help="value for tempfile.tempdir",
    )
    demo_parser.add_argument(
        "--silent",
        action='store_true',
        default=False,
        help="silence logs",
    )
    return demo_parser

if __name__ == '__main__':
    # Script entry point: parse the CLI options, build the models and hand
    # everything to main_demo inside a temporary working directory.
    parser = get_args_parser()
    args = parser.parse_args()
    # set_print_with_timestamp()

    # When --tmp_dir is given, create it if needed and make it the base
    # directory used by the tempfile module.
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    # An explicit --server_name takes precedence; otherwise --local_network
    # selects 0.0.0.0 and the fallback is 127.0.0.1.
    if args.server_name is not None:
        server_name = args.server_name
    else:
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    # `device` used to be passed to main_demo without ever being assigned
    # (the --device CLI option is commented out), which raised NameError.
    # Pick CUDA when torch reports it as available, otherwise use the CPU.
    # NOTE(review): assumes main_demo accepts a torch device string such as
    # 'cuda' / 'cpu' -- confirm against modules.pe3r.demo.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pe3r = Models()

    # The temporary directory is removed when the `with` block exits.
    with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
        if not args.silent:
            print('Outputing stuff in', tmpdirname)
        main_demo(tmpdirname, pe3r, device, server_name, args.server_port, silent=args.silent)