LittleFrog committed
Commit 793e910 · 1 Parent(s): f17fbb8

Refactor get devices

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -90,9 +90,14 @@ description = \
   - You can optionally generate a high-resolution sample if the input image is of high resolution. We split the original image into `Vertical Splits` by `Horizontal Splits` patches with some `Overlaps` in between. Due to computation constraints for the online demo, we recommend `Vertical Splits` x `Horizontal Splits` to be no more than 6 and to set 2 for `Overlaps`. The denoising steps should at least be set to 80 for high resolution samples.

 """
+@spaces.GPU()
+def get_devices():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    return device

 set_loggers("INFO")
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = get_devices()
+set_loggers(f"Using devices: {device}")

 # Download from model
 logger.info(f"Downloading Models...")
@@ -102,10 +107,10 @@ logger.info(f"Loading Models...")
 model_dict = {
     "Albedo": InferenceModel(ckpt_path="weights/albedo",
                              use_ddim=True,
-                             gpu_id=0),
+                             gpu_id=device),
     "Specular": InferenceModel(ckpt_path="weights/specular",
                                use_ddim=True,
-                               gpu_id=0),
+                               gpu_id=device),
     "remove_bg": rembg.new_session(),
 }
 logger.info(f"All models Loaded!")