Ziqi committed on
Commit
3f68383
·
1 Parent(s): 0e8667f
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -16,6 +16,7 @@ import sys
16
  import os
17
  import pathlib
18
 
 
19
  import gradio as gr
20
  import torch
21
 
@@ -24,6 +25,19 @@ from inference import inference_fn
24
  # from trainer import Trainer
25
  # from uploader import upload
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  TITLE = '# ReVersion'
28
  DESCRIPTION = '''This is a demo for [https://github.com/ziqihuangg/ReVersion](https://github.com/ziqihuangg/ReVersion).
29
  It is recommended to upgrade to GPU in Settings after duplicating this space to use it.
@@ -149,6 +163,10 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
149
  return demo
150
 
151
 
 
 
 
 
152
  with gr.Blocks(css='style.css') as demo:
153
  if os.getenv('IS_SHARED_UI'):
154
  show_warning(SHARED_UI_WARNING)
@@ -164,6 +182,10 @@ with gr.Blocks(css='style.css') as demo:
164
  with gr.TabItem('Test'):
165
  create_inference_demo(inference_fn)
166
 
167
-
168
- demo.queue(default_enabled=False).launch(share=False)
 
 
 
 
169
 
 
16
  import os
17
  import pathlib
18
 
19
+ import argparse
20
  import gradio as gr
21
  import torch
22
 
 
25
  # from trainer import Trainer
26
  # from uploader import upload
27
 
28
+
29
+ def parse_args() -> argparse.Namespace:
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument('--device', type=str, default='cpu')
32
+ parser.add_argument('--theme', type=str)
33
+ parser.add_argument('--share', action='store_true')
34
+ parser.add_argument('--port', type=int)
35
+ parser.add_argument('--disable-queue',
36
+ dest='enable_queue',
37
+ action='store_false')
38
+ return parser.parse_args()
39
+
40
+
41
  TITLE = '# ReVersion'
42
  DESCRIPTION = '''This is a demo for [https://github.com/ziqihuangg/ReVersion](https://github.com/ziqihuangg/ReVersion).
43
  It is recommended to upgrade to GPU in Settings after duplicating this space to use it.
 
163
  return demo
164
 
165
 
166
+ args = parse_args()
167
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
168
+ print('*** Now using %s.'%(args.device))
169
+
170
  with gr.Blocks(css='style.css') as demo:
171
  if os.getenv('IS_SHARED_UI'):
172
  show_warning(SHARED_UI_WARNING)
 
182
  with gr.TabItem('Test'):
183
  create_inference_demo(inference_fn)
184
 
185
+ demo.launch(
186
+ enable_queue=args.enable_queue,
187
+ server_port=args.port,
188
+ share=args.share,
189
+ )
190
+ # demo.queue(default_enabled=False).launch(server_port=args.port, share=args.share)
191