from typing import List, Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution


class AutoencoderKL(nn.Module):
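    """KL-regularized image autoencoder (the first-stage VAE used with latent
    diffusion models). `encode` maps images to a diagonal Gaussian posterior
    over latents; `decode` maps sampled latents back to image space.
    """
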
    def __init__(
        self,
        double_z: bool = True,
        z_channels: int = 3,
        resolution: int = 512,
        in_channels: int = 3,
        out_ch: int = 3,
        ch: int = 128,
        ch_mult: List[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_resolutions: List[int] = [],
        dropout: float = 0.0,
        embed_dim: int = 3,
        ckpt_path: Optional[str] = None,
        ignore_keys: List[str] = [],
    ):
        super().__init__()
        ddconfig = {
            "double_z": double_z,
            "z_channels": z_channels,
            "resolution": resolution,
            "in_channels": in_channels,
            "out_ch": out_ch,
            "ch": ch,
            "ch_mult": ch_mult,
            "num_res_blocks": num_res_blocks,
            "attn_resolutions": attn_resolutions,
            "dropout": dropout
        }
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = nn.Conv2d(
            2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
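        """Load weights from a checkpoint's state_dict, dropping every key
        that starts with one of the prefixes in `ignore_keys`."""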
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print(f"Deleting key {k} from state_dict.")
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
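        """Encode a batch of images into a diagonal Gaussian posterior over latents."""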
        h = self.encoder(x)  # (B, C, h, w)
        moments = self.quant_conv(h)  # (B, 2 * embed_dim, h, w): mean and log-variance
        posterior = DiagonalGaussianDistribution(moments)
        return posterior  # a distribution object, not a tensor

    def decode(self, z):
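        """Decode a batch of latents back to image space."""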
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
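        """Encode, sample (or take the mode of) the posterior, and decode.

        Also returns the decoder's final conv weight, e.g. for computing an
        adaptive loss weight during VAE-GAN style training.
        """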
        posterior = self.encode(input)  # Gaussian posterior
        if sample_posterior:
            z = posterior.sample()  # draw a latent from the posterior
        else:
            z = posterior.mode()  # use the posterior mean instead
        dec = self.decode(z)
        last_layer_weight = self.decoder.conv_out.weight
        return dec, posterior, last_layer_weight


if __name__ == '__main__':
    # Test the input and output shapes of the model
    model = AutoencoderKL()
    x = torch.randn(1, 3, 512, 512)
    dec, posterior, last_layer_weight = model(x)

    assert dec.shape == (1, 3, 512, 512)
    assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
    assert last_layer_weight.shape == (3, 128, 3, 3)
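    # 512 -> 64: the encoder downsamples by 2 ** (len(ch_mult) - 1) = 8, and
    # the latent has embed_dim (= 3) channels.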

    # Plot the latent space and the reconstruction from the pretrained model
    model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
    model.eval()
    image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"

    from PIL import Image
    import numpy as np
    from src.data.components.celeba import DalleTransformerPreprocessor

    image = Image.open(image_path).convert('RGB')
    image = np.array(image).astype(np.uint8)
    original = image.copy()
    transform = DalleTransformerPreprocessor(size=512, phase='test')
    image = transform(image=image)['image']
    image = image.astype(np.float32) / 127.5 - 1.0  # scale to [-1, 1]
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW

    dec, posterior, last_layer_weight = model(image)

    # original image
    plt.subplot(1, 3, 1)
    plt.imshow(original)
    plt.title("Original")
    plt.axis("off")

    def to_uint8(t: torch.Tensor) -> np.ndarray:
        """Map a (1, 3, H, W) tensor in [-1, 1] to an HWC uint8 array."""
        t = (t + 1) / 2
        t = t.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        return (t.clip(0, 1) * 255).astype(np.uint8)

    # sampled image from the latent space
    plt.subplot(1, 3, 2)
    plt.imshow(to_uint8(model.decode(posterior.sample())))
    plt.title("Sampled")
    plt.axis("off")

    # reconstructed image
    plt.subplot(1, 3, 3)
    plt.imshow(to_uint8(dec))
    plt.title("Reconstructed")
    plt.axis("off")

    plt.tight_layout()
    plt.savefig("vae_reconstruction.png")