# Hugging Face Hub commit metadata (scrape residue, kept as comments):
# author: caixiaoshun
# message: 使用huggingface hub尝试更新 (attempted an update via the Hugging Face Hub)
# revision: fa7be76 (verified)
from typing import List
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (the VAE used by latent diffusion).

    Wraps the LDM ``Encoder``/``Decoder`` pair: ``encode`` maps an image
    batch to a diagonal-Gaussian posterior over the latent space and
    ``decode`` maps a latent sample back to image space. 1x1 convolutions
    translate between the encoder's moment channels (2 * z_channels: mean
    and log-variance stacked along the channel axis) and the external
    embedding dimension ``embed_dim``.
    """

    def __init__(
        self,
        double_z: bool = True,
        z_channels: int = 3,
        resolution: int = 512,
        in_channels: int = 3,
        out_ch: int = 3,
        ch: int = 128,
        ch_mult: List = None,          # defaults to [1, 2, 4, 4]
        num_res_blocks: int = 2,
        attn_resolutions: List = None,  # defaults to []
        dropout: float = 0.0,
        embed_dim: int = 3,
        ckpt_path: str = None,
        ignore_keys: List = None,       # defaults to []
    ):
        """Build the encoder/decoder backbone and optionally restore weights.

        Args:
            double_z: must be True; the encoder emits mean and log-variance.
            z_channels: latent channels produced by the encoder.
            resolution: input image resolution the backbone is built for.
            in_channels: channels of the input image.
            out_ch: channels of the decoded image.
            ch: base channel width of the conv backbone.
            ch_mult: per-level channel multipliers.
            num_res_blocks: residual blocks per resolution level.
            attn_resolutions: resolutions at which attention is applied.
            dropout: dropout rate inside the backbone.
            embed_dim: channel width of the external latent embedding.
            ckpt_path: optional checkpoint path to restore from.
            ignore_keys: state-dict key prefixes to drop when restoring.
        """
        super().__init__()
        # None sentinels instead of mutable default arguments, so defaults
        # are not shared across instances/calls.
        if ch_mult is None:
            ch_mult = [1, 2, 4, 4]
        if attn_resolutions is None:
            attn_resolutions = []
        if ignore_keys is None:
            ignore_keys = []
        ddconfig = {
            "double_z": double_z,
            "z_channels": z_channels,
            "resolution": resolution,
            "in_channels": in_channels,
            "out_ch": out_ch,
            "ch": ch,
            "ch_mult": ch_mult,
            "num_res_blocks": num_res_blocks,
            "attn_resolutions": attn_resolutions,
            "dropout": dropout,
        }
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # quant_conv consumes mean+logvar stacked along channels, so the
        # encoder must emit 2 * z_channels (double_z).
        assert ddconfig["double_z"], "AutoencoderKL requires double_z=True"
        self.quant_conv = nn.Conv2d(
            2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from a Lightning-style checkpoint.

        Drops every state-dict entry whose key starts with any prefix in
        ``ignore_keys`` before loading.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        trusted checkpoint files.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # Materialize the key list so deleting entries while iterating is
        # safe; `any()` ensures each key is deleted at most once even when
        # it matches several prefixes (the original loop could KeyError).
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                print(f"Deleting key {k} from state_dict.")
                del sd[k]
        # strict=False tolerates missing/unexpected keys (e.g. loss modules
        # saved alongside the VAE).
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode image batch ``x`` into a diagonal-Gaussian posterior."""
        h = self.encoder(x)  # B, 2*z_channels, h, w
        moments = self.quant_conv(h)  # B, 2*embed_dim, h, w (mean, logvar)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior  # latent posterior distribution

    def decode(self, z):
        """Decode latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full round trip: encode, sample (or take the mode), decode.

        Args:
            input: image batch to reconstruct.
            sample_posterior: sample the posterior when True, otherwise use
                its mode (deterministic).

        Returns:
            Tuple of (reconstruction, posterior, last decoder conv weight);
            the weight is exposed for losses that need the output layer
            (e.g. adaptive adversarial weighting).
        """
        posterior = self.encode(input)  # Gaussian posterior over latents
        if sample_posterior:
            z = posterior.sample()  # stochastic latent draw
        else:
            z = posterior.mode()
        dec = self.decode(z)
        last_layer_weight = self.decoder.conv_out.weight
        return dec, posterior, last_layer_weight
if __name__ == '__main__':
    # Test the input and output shapes of the model
    model = AutoencoderKL()
    x = torch.randn(1, 3, 512, 512)
    dec, posterior, last_layer_weight = model(x)
    assert dec.shape == (1, 3, 512, 512)
    assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
    assert last_layer_weight.shape == (3, 128, 3, 3)

    # Plot the latent space and the reconstruction from the pretrained model
    model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
    model.eval()
    image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"
    from PIL import Image
    import numpy as np
    import copy
    from src.data.components.celeba import DalleTransformerPreprocessor

    def to_display_image(t):
        """Map a (1, 3, H, W) tensor in [-1, 1] to an HxWx3 uint8 array."""
        t = (t + 1) / 2
        arr = t.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
        return (arr.clip(0, 1) * 255).astype(np.uint8)

    image = Image.open(image_path).convert('RGB')
    image = np.array(image).astype(np.uint8)
    original = copy.deepcopy(image)
    transform = DalleTransformerPreprocessor(size=512, phase='test')
    image = transform(image=image)['image']
    # Normalize uint8 [0, 255] to float [-1, 1], the model's input range.
    image = image.astype(np.float32)/127.5 - 1.0
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        dec, posterior, last_layer_weight = model(image)
        sampled = model.decode(posterior.sample())

    # original image
    plt.subplot(1, 3, 1)
    plt.imshow(original)
    plt.title("Original")
    plt.axis("off")
    # sampled image from the latent space
    plt.subplot(1, 3, 2)
    plt.imshow(to_display_image(sampled))
    plt.title("Sampled")
    plt.axis("off")
    # reconstructed image
    plt.subplot(1, 3, 3)
    plt.imshow(to_display_image(dec))
    plt.title("Reconstructed")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig("vae_reconstruction.png")