hujiecpp commited on
Commit
16cc0ca
·
1 Parent(s): 3b40174

init project

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -10,6 +10,7 @@ import argparse
10
 
11
  from modules.pe3r.demo import main_demo
12
  from modules.pe3r.models import Models
 
13
 
14
  # def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
15
  # builtin_print = builtins.print
@@ -20,6 +21,8 @@ from modules.pe3r.models import Models
20
  # builtin_print(*args, **kwargs)
21
  # builtins.print = print_with_timestamp
22
 
 
 
23
  def get_args_parser():
24
  parser = argparse.ArgumentParser()
25
  parser_url = parser.add_mutually_exclusive_group()
@@ -28,7 +31,7 @@ def get_args_parser():
28
  parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
29
  parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
30
  "If None, will search for an available port starting at 7860."), default=None)
31
- parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
32
  parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
33
  parser.add_argument("--silent", action='store_true', default=False, help="silence logs")
34
  # change defaults
@@ -50,9 +53,9 @@ if __name__ == '__main__':
50
  else:
51
  server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
52
 
53
- pe3r = Models(device=args.device)
54
 
55
  with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
56
  if not args.silent:
57
  print('Outputing stuff in', tmpdirname)
58
- main_demo(tmpdirname, pe3r, args.device, server_name, args.server_port, silent=args.silent)
 
10
 
11
  from modules.pe3r.demo import main_demo
12
  from modules.pe3r.models import Models
13
+ import torch
14
 
15
  # def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
16
  # builtin_print = builtins.print
 
21
  # builtin_print(*args, **kwargs)
22
  # builtins.print = print_with_timestamp
23
 
24
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
+
26
  def get_args_parser():
27
  parser = argparse.ArgumentParser()
28
  parser_url = parser.add_mutually_exclusive_group()
 
31
  parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
32
  parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
33
  "If None, will search for an available port starting at 7860."), default=None)
34
+ # parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
35
  parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
36
  parser.add_argument("--silent", action='store_true', default=False, help="silence logs")
37
  # change defaults
 
53
  else:
54
  server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
55
 
56
+ pe3r = Models(device=device)
57
 
58
  with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
59
  if not args.silent:
60
  print('Outputing stuff in', tmpdirname)
61
+ main_demo(tmpdirname, pe3r, device, server_name, args.server_port, silent=args.silent)