adaface-neurips commited on
Commit
6be3e80
·
1 Parent(s): 20fe0e9

Improve device assignment

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -20,6 +20,9 @@ def str2bool(v):
20
  else:
21
  raise argparse.ArgumentTypeError("Boolean value expected.")
22
 
 
 
 
23
  import argparse
24
  parser = argparse.ArgumentParser()
25
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
@@ -68,8 +71,6 @@ base_model_path = model_style_type2base_model_path[args.model_style_type]
68
 
69
  # global variable
70
  MAX_SEED = np.iinfo(np.int32).max
71
- device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
72
- print(f"Device: {device}")
73
 
74
  global adaface
75
  adaface = None
@@ -113,6 +114,16 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
113
 
114
  global adaface
115
 
 
 
 
 
 
 
 
 
 
 
116
  adaface.to(device)
117
 
118
  if image_paths is None or len(image_paths) == 0:
 
20
  else:
21
  raise argparse.ArgumentTypeError("Boolean value expected.")
22
 
23
+ def is_running_on_spaces():
24
+ return os.getenv("SPACE_ID") is not None
25
+
26
  import argparse
27
  parser = argparse.ArgumentParser()
28
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
 
71
 
72
  # global variable
73
  MAX_SEED = np.iinfo(np.int32).max
 
 
74
 
75
  global adaface
76
  adaface = None
 
114
 
115
  global adaface
116
 
117
+ if is_running_on_spaces():
118
+ device = 'cuda:0'
119
+ else:
120
+ if args.gpu is None:
121
+ device = "cuda"
122
+ else:
123
+ device = f"cuda:{args.gpu}"
124
+
125
+ print(f"Device: {device}")
126
+
127
  adaface.to(device)
128
 
129
  if image_paths is None or len(image_paths) == 0: