yeq6x commited on
Commit
d7a562e
1 Parent(s): c9cc441

arg flag追加

Browse files
Files changed (2) hide show
  1. app.py +7 -1
  2. process_utils.py +7 -7
app.py CHANGED
@@ -3,6 +3,7 @@ from flask_socketio import SocketIO, emit
3
  from flask_cors import CORS
4
  import io
5
  import os
 
6
  from PIL import Image
7
  import torch
8
  import gc
@@ -183,5 +184,10 @@ def server_error(e):
183
  return jsonify(error=str(e)), 500
184
 
185
  if __name__ == '__main__':
186
- initialize(local_model=True)
 
 
 
 
 
187
  socketio.run(app, debug=True, host='0.0.0.0', port=5000)
 
3
  from flask_cors import CORS
4
  import io
5
  import os
6
+ import argparse
7
  from PIL import Image
8
  import torch
9
  import gc
 
184
  return jsonify(error=str(e)), 500
185
 
186
  if __name__ == '__main__':
187
+ parser = argparse.ArgumentParser(description='Server options.')
188
+ parser.add_argument('--local_model', type=bool, default=False, help='Use local model')
189
+ parser.add_argument('--use_gpu', type=bool, default=True, help='Set to True to use GPU but if not available, it will use CPU')
190
+ args = parser.parse_args()
191
+
192
+ initialize(local_model=args.local_model, use_gpu=args.use_gpu)
193
  socketio.run(app, debug=True, host='0.0.0.0', port=5000)
process_utils.py CHANGED
@@ -18,9 +18,8 @@ load_dotenv()
18
  # グローバル変数
19
  local_model = False
20
  model = None
21
- # device = "cuda" if torch.cuda.is_available() else "cpu"
22
- device = "cpu"
23
- torch_dtype = torch.float16 if device == "cuda" else torch.float32
24
  sotai_gen_pipe = None
25
  refine_gen_pipe = None
26
 
@@ -44,9 +43,10 @@ def ensure_rgb(image):
44
  return image.convert('RGB')
45
  return image
46
 
47
- def initialize(_local_model=False):
48
- global model, sotai_gen_pipe, refine_gen_pipe, local_model
49
-
 
50
  local_model = _local_model
51
  model = load_wd14_tagger_model()
52
  sotai_gen_pipe = initialize_sotai_model()
@@ -225,7 +225,7 @@ def generate_sotai_image(input_image: Image.Image, output_width: int, output_hei
225
  image=[input_image, input_image],
226
  negative_prompt=f"(wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
227
  # negative_prompt=f"{easy_negative_v2}, (wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
228
- num_inference_steps=40,
229
  guidance_scale=8,
230
  width=output_width,
231
  height=output_height,
 
18
  # グローバル変数
19
  local_model = False
20
  model = None
21
+ device = None
22
+ torch_dtype = None # torch.float16 if device == "cuda" else torch.float32
 
23
  sotai_gen_pipe = None
24
  refine_gen_pipe = None
25
 
 
43
  return image.convert('RGB')
44
  return image
45
 
46
+ def initialize(_local_model=False, use_gpu=True)
47
+ global model, sotai_gen_pipe, refine_gen_pipe, local_model, device, torch_dtype
48
+ device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
49
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
50
  local_model = _local_model
51
  model = load_wd14_tagger_model()
52
  sotai_gen_pipe = initialize_sotai_model()
 
225
  image=[input_image, input_image],
226
  negative_prompt=f"(wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
227
  # negative_prompt=f"{easy_negative_v2}, (wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
228
+ num_inference_steps=20,
229
  guidance_scale=8,
230
  width=output_width,
231
  height=output_height,