Upload anytext.py
Browse files- anytext.py +4 -1
anytext.py
CHANGED
@@ -822,7 +822,10 @@ class AuxiliaryLatentModule(nn.Module):
|
|
822 |
# get masked_x
|
823 |
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
824 |
masked_img = np.transpose(masked_img, (2, 0, 1))
|
825 |
-
|
|
|
|
|
|
|
826 |
if self.use_fp16:
|
827 |
masked_img = masked_img.half()
|
828 |
masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
|
|
|
822 |
# get masked_x
|
823 |
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
824 |
masked_img = np.transpose(masked_img, (2, 0, 1))
|
825 |
+
print("vae device", next(self.vae.parameters()).device)
|
826 |
+
print("masked_img device", self.device)
|
827 |
+
device = next(self.vae.parameters()).device
|
828 |
+
masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
|
829 |
if self.use_fp16:
|
830 |
masked_img = masked_img.half()
|
831 |
masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
|