# KL-regularized VAE (AutoencoderKL) with a self-test / visualization script.
from typing import List, Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (the VAE used by latent diffusion).

    ``encode`` maps an image to a diagonal Gaussian posterior over a
    spatially-downsampled latent; ``decode`` maps a latent back to image
    space. ``forward`` returns the reconstruction, the posterior, and the
    decoder's final conv weight (commonly used for adaptive loss weighting).
    """

    def __init__(
        self,
        double_z: bool = True,
        z_channels: int = 3,
        resolution: int = 512,
        in_channels: int = 3,
        out_ch: int = 3,
        ch: int = 128,
        ch_mult: Optional[List] = None,
        num_res_blocks: int = 2,
        attn_resolutions: Optional[List] = None,
        dropout: float = 0.0,
        embed_dim: int = 3,
        ckpt_path: Optional[str] = None,
        ignore_keys: Optional[List] = None,
    ):
        """Build the encoder/decoder pair and optionally restore weights.

        The first ten arguments mirror the ldm ``ddconfig`` keys and are
        forwarded verbatim to ``Encoder``/``Decoder``. ``embed_dim`` is the
        channel count of the latent exposed to callers; ``ckpt_path``, if
        given, is a checkpoint to restore from, skipping any state-dict key
        that starts with a prefix in ``ignore_keys``.
        """
        super().__init__()
        # Resolve defaults per instance instead of using mutable default
        # arguments (which would be shared across all constructions).
        ch_mult = [1, 2, 4, 4] if ch_mult is None else ch_mult
        attn_resolutions = [] if attn_resolutions is None else attn_resolutions
        ignore_keys = [] if ignore_keys is None else ignore_keys
        ddconfig = {
            "double_z": double_z,
            "z_channels": z_channels,
            "resolution": resolution,
            "in_channels": in_channels,
            "out_ch": out_ch,
            "ch": ch,
            "ch_mult": ch_mult,
            "num_res_blocks": num_res_blocks,
            "attn_resolutions": attn_resolutions,
            "dropout": dropout,
        }
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # double_z is required: quant_conv expects mean and logvar stacked
        # along the channel dimension (2 * z_channels input channels).
        assert ddconfig["double_z"]
        self.quant_conv = nn.Conv2d(
            2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights from a checkpoint's ``state_dict``, non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped (with a
        log line) before loading.
        """
        ignore_keys = [] if ignore_keys is None else ignore_keys
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # startswith accepts a tuple of prefixes; deleting at most once
            # per key avoids a double-del KeyError when prefixes overlap.
            if ignore_keys and k.startswith(tuple(ignore_keys)):
                print(f"Deleting key {k} from state_dict.")
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode image ``x`` into a diagonal Gaussian posterior."""
        h = self.encoder(x)  # B, C, h, w
        moments = self.quant_conv(h)  # B, 2*embed_dim, h, w (mean + logvar)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior  # a distribution, not a tensor

    def decode(self, z):
        """Decode latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode ``input``, (optionally) sample the posterior, and decode.

        Returns ``(reconstruction, posterior, last_layer_weight)`` where the
        last element is the decoder's output-conv weight.
        """
        posterior = self.encode(input)  # Gaussian posterior
        if sample_posterior:
            z = posterior.sample()  # stochastic sample (reparameterized)
        else:
            z = posterior.mode()  # deterministic mode (the mean)
        dec = self.decode(z)
        last_layer_weight = self.decoder.conv_out.weight
        return dec, posterior, last_layer_weight
if __name__ == '__main__':
    import copy

    import numpy as np
    from PIL import Image

    from src.data.components.celeba import DalleTransformerPreprocessor

    def to_uint8_image(t):
        """Convert a (1, 3, H, W) tensor in [-1, 1] to an HWC uint8 array."""
        t = (t + 1) / 2
        t = t.squeeze(0).permute(1, 2, 0).cpu()
        arr = t.detach().numpy().clip(0, 1)
        return (arr * 255).astype(np.uint8)

    # --- Smoke-test the input/output shapes of an untrained model ---
    model = AutoencoderKL()
    x = torch.randn(1, 3, 512, 512)
    dec, posterior, last_layer_weight = model(x)
    assert dec.shape == (1, 3, 512, 512)
    # Latent is 8x downsampled spatially by the default ch_mult=[1, 2, 4, 4].
    assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
    assert last_layer_weight.shape == (3, 128, 3, 3)

    # --- Plot original / sampled / reconstructed from a pretrained model ---
    model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
    model.eval()
    image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"
    image = Image.open(image_path).convert('RGB')
    image = np.array(image).astype(np.uint8)
    original = copy.deepcopy(image)  # keep the untransformed image for display
    transform = DalleTransformerPreprocessor(size=512, phase='test')
    image = transform(image=image)['image']
    # Map uint8 [0, 255] to float [-1, 1], then HWC -> 1CHW.
    image = image.astype(np.float32)/127.5 - 1.0
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        dec, posterior, last_layer_weight = model(image)
        sampled = model.decode(posterior.sample())

    # Original image.
    plt.subplot(1, 3, 1)
    plt.imshow(original)
    plt.title("Original")
    plt.axis("off")

    # Image decoded from a sample of the latent posterior.
    plt.subplot(1, 3, 2)
    plt.imshow(to_uint8_image(sampled))
    plt.title("Sampled")
    plt.axis("off")

    # Reconstruction returned by the forward pass.
    plt.subplot(1, 3, 3)
    plt.imshow(to_uint8_image(dec))
    plt.title("Reconstructed")
    plt.axis("off")

    plt.tight_layout()
    plt.savefig("vae_reconstruction.png")