alexnasa commited on
Commit
7130239
·
verified ·
1 Parent(s): b85f7c2

Update inference_coz.py

Browse files
Files changed (1) hide show
  1. inference_coz.py +8 -8
inference_coz.py CHANGED
@@ -175,11 +175,11 @@ if __name__ == "__main__":
175
  if not args.efficient_memory:
176
  from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler
177
  model = SD3Euler()
178
- model.text_enc_1.to('cuda:0')
179
- model.text_enc_2.to('cuda:0')
180
- model.text_enc_3.to('cuda:0')
181
- model.transformer.to('cuda:1', dtype=torch.float32)
182
- model.vae.to('cuda:1', dtype=torch.float32)
183
  for p in [model.text_enc_1, model.text_enc_2, model.text_enc_3, model.transformer, model.vae]:
184
  p.requires_grad_(False)
185
  model_test = OSEDiff_SD3_TEST(args, model)
@@ -335,9 +335,9 @@ if __name__ == "__main__":
335
  if args.efficient_memory and model is not None:
336
  print("Ensuring SR model components are on CUDA for SR inference.")
337
  if not isinstance(model_test, OSEDiff_SD3_TEST_efficient):
338
- model.text_enc_1.to('cuda:0')
339
- model.text_enc_2.to('cuda:0')
340
- model.text_enc_3.to('cuda:0')
341
  # transformer and VAE should already be on CUDA per initialization
342
  model.transformer.to('cuda', dtype=torch.float32)
343
  model.vae.to('cuda', dtype=torch.float32)
 
175
  if not args.efficient_memory:
176
  from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler
177
  model = SD3Euler()
178
+ model.text_enc_1.to('cuda')
179
+ model.text_enc_2.to('cuda')
180
+ model.text_enc_3.to('cuda')
181
+ model.transformer.to('cuda', dtype=torch.float32)
182
+ model.vae.to('cuda', dtype=torch.float32)
183
  for p in [model.text_enc_1, model.text_enc_2, model.text_enc_3, model.transformer, model.vae]:
184
  p.requires_grad_(False)
185
  model_test = OSEDiff_SD3_TEST(args, model)
 
335
  if args.efficient_memory and model is not None:
336
  print("Ensuring SR model components are on CUDA for SR inference.")
337
  if not isinstance(model_test, OSEDiff_SD3_TEST_efficient):
338
+ model.text_enc_1.to('cuda')
339
+ model.text_enc_2.to('cuda')
340
+ model.text_enc_3.to('cuda')
341
  # transformer and VAE should already be on CUDA per initialization
342
  model.transformer.to('cuda', dtype=torch.float32)
343
  model.vae.to('cuda', dtype=torch.float32)