# Page header captured with this file (not part of the program):
#   Spaces status: "Runtime error"; file size: 5,419 Bytes; revision cdee5b8.
#   (The viewer's line-number gutter, 1-148, has been omitted.)
from diffusers import AutoencoderKL
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import os
class VAE():
    """
    Thin wrapper around a pretrained ``AutoencoderKL`` that turns images into
    scaled latents and scaled latents back into BGR ``uint8`` images.
    """

    def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
        """
        Load the autoencoder and prepare the preprocessing helpers.

        :param model_path: Directory passed to ``AutoencoderKL.from_pretrained``.
        :param resized_img: Side length (pixels) that images loaded from disk
            are resized to; also the side length of the half mask.
        :param use_float16: If truthy, convert the model to half precision.
        """
        self.model_path = model_path
        self.vae = AutoencoderKL.from_pretrained(self.model_path)
        # Prefer the GPU when one is available, otherwise stay on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vae.to(self.device)

        self._use_float16 = bool(use_float16)
        if self._use_float16:
            self.vae = self.vae.half()

        self.scaling_factor = self.vae.config.scaling_factor
        # Maps [0, 1] pixel values to [-1, 1]: (x - 0.5) / 0.5 per channel.
        self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self._resized_img = resized_img
        self._mask_tensor = self.get_mask_tensor()

    def get_mask_tensor(self):
        """
        Build the half mask used by :meth:`preprocess_img`.

        :return: A float tensor of shape ``(resized_img, resized_img)`` whose
            first ``resized_img // 2`` rows are 1 and whose remaining rows are 0.
        """
        mask_tensor = torch.zeros((self._resized_img, self._resized_img))
        # The values are exactly 0 or 1 already, so no thresholding is needed.
        mask_tensor[:self._resized_img // 2, :] = 1
        return mask_tensor

    def preprocess_img(self, img_name, half_mask=False):
        """
        Convert one image into a normalized tensor on the model's device.

        :param img_name: Either a file path (``str``) or an already-loaded
            image array. A path is read with OpenCV and resized to
            ``resized_img`` x ``resized_img``. An array is treated as BGR
            (it is converted with ``COLOR_BGR2RGB``) and is NOT resized here,
            so with ``half_mask=True`` its height and width must match the mask.
        :param half_mask: If true, zero out the rows not covered by the mask
            (the lower half) before normalization; after normalization those
            pixels therefore hold -1.
        :return: A float tensor shaped ``[1, 3, H, W]`` in RGB channel order.
        :raises FileNotFoundError: If ``img_name`` is a path that OpenCV
            cannot read (missing or undecodable file).
        """
        if isinstance(img_name, str):
            img = cv2.imread(img_name)
            if img is None:
                # cv2.imread signals failure by returning None instead of
                # raising; fail here with the offending path rather than with
                # an opaque error inside cvtColor.
                raise FileNotFoundError(
                    f"Could not read image (missing or unreadable): {img_name}"
                )
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self._resized_img, self._resized_img),
                             interpolation=cv2.INTER_LANCZOS4)
        else:
            img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)

        # Scale to [0, 1] and reorder HWC -> CHW.
        x = np.transpose(np.asarray(img) / 255., (2, 0, 1))
        x = torch.from_numpy(np.ascontiguousarray(x)).float()
        if half_mask:
            x = x * (self._mask_tensor > 0.5)
        x = self.transform(x)
        x = x.unsqueeze(0)  # add the batch dimension: [1, 3, H, W]
        x = x.to(self.vae.device)
        return x

    def encode_latents(self, image):
        """
        Encode an image tensor into scaled latents.

        The latents are drawn with ``latent_dist.sample()``, so repeated calls
        on the same input are not guaranteed to return identical values.

        :param image: Image tensor as produced by :meth:`preprocess_img`.
        :return: The sampled latents multiplied by ``scaling_factor``.
        """
        with torch.no_grad():
            init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
            init_latents = self.scaling_factor * init_latent_dist.sample()
        return init_latents

    def decode_latents(self, latents):
        """
        Decode scaled latents back into images.

        :param latents: Latents in the scale produced by :meth:`encode_latents`.
        :return: A ``uint8`` NumPy array shaped ``[N, H, W, 3]`` in BGR channel
            order (a reversed-channel view of the RGB result).
        """
        # Inference only: skip building an autograd graph, mirroring
        # encode_latents. The result was detached before returning anyway.
        with torch.no_grad():
            latents = (1 / self.scaling_factor) * latents
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
        # Undo the [-1, 1] normalization and clip to the valid [0, 1] range.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = image[..., ::-1]  # RGB to BGR
        return image

    def get_latents_for_unet(self, img):
        """
        Build the U-Net input from one image.

        The image is encoded twice, once with the half mask applied and once
        unmasked, and the two latents are concatenated along ``dim=1`` with
        the masked latents first.

        :param img: A file path or BGR image array (see :meth:`preprocess_img`).
        :return: The concatenated latent tensor.
        """
        masked_image = self.preprocess_img(img, half_mask=True)
        masked_latents = self.encode_latents(masked_image)
        ref_image = self.preprocess_img(img, half_mask=False)
        ref_latents = self.encode_latents(ref_image)
        latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
        return latent_model_input
if __name__ == "__main__":
    # Demo: encode every PNG crop in a directory and print each latent's size.
    vae_mode_path = "./models/sd-vae-ft-mse/"
    vae = VAE(model_path=vae_mode_path, use_float16=False)

    crop_imgs_path = "./results/sun001_crop/"
    latents_out_path = "./results/latents/"
    # makedirs also creates missing parents and tolerates an existing
    # directory, avoiding the exists()/mkdir() race.
    os.makedirs(latents_out_path, exist_ok=True)

    files = sorted(name for name in os.listdir(crop_imgs_path) if name.endswith(".png"))
    for file in files:
        img_path = os.path.join(crop_imgs_path, file)
        latents = vae.get_latents_for_unet(img_path)
        print(img_path, "latents", latents.size())
        # NOTE: writing the latents to latents_out_path is currently disabled;
        # this script only reports their sizes.
# End of file.