venkyyuvy commited on
Commit
85a3411
1 Parent(s): 41ee56a

device fix

Browse files
Files changed (1) hide show
  1. utils.py +1 -1
utils.py CHANGED
@@ -21,7 +21,7 @@ def pil_to_latent(input_im):
21
  latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
22
  return 0.18215 * latent.latent_dist.sample()
23
 
24
- def latents_to_pil(latents, torch_device="mps:0"):
25
  # bath of latents -> list of images
26
  latents = (1 / 0.18215) * latents
27
  with torch.no_grad():
 
21
  latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
22
  return 0.18215 * latent.latent_dist.sample()
23
 
24
+ def latents_to_pil(latents, torch_device=torch_device):
25
  # bath of latents -> list of images
26
  latents = (1 / 0.18215) * latents
27
  with torch.no_grad():