Delete unused files
caixiaoshun committed
Commit · dca1999
1 Parent(s): 2c73586
Browse files
- src/models/components/cnn.py +0 -26
- src/models/components/lnn.py +0 -23
- src/models/components/unet.py +0 -63
- src/models/components/vae.py +0 -152
src/models/components/cnn.py
DELETED
@@ -1,26 +0,0 @@
-import torch
-from torch import nn
-
-
-class CNN(nn.Module):
-    def __init__(self, dim=32):
-        super(CNN, self).__init__()
-        self.conv1 = nn.Conv2d(1, dim, 5)
-        self.conv2 = nn.Conv2d(dim, dim * 2, 5)
-        self.fc1 = nn.Linear(dim * 2 * 4 * 4, 10)
-
-    def forward(self, x):
-        x = torch.relu(self.conv1(x))
-        x = torch.max_pool2d(x, 2)
-        x = torch.relu(self.conv2(x))
-        x = torch.max_pool2d(x, 2)
-        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
-        x = self.fc1(x)
-        return x
-
-
-if __name__ == "__main__":
-    input = torch.randn(2, 1, 28, 28)
-    model = CNN()
-    output = model(input)
-    assert output.shape == (2, 10)
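Editor's note: the deleted cnn.py sizes fc1 as nn.Linear(dim * 2 * 4 * 4, 10). The 4 x 4 follows from the 28 x 28 MNIST input: each valid (unpadded) 5 x 5 convolution shrinks the side by 4 and each 2 x 2 max-pool halves it. A minimal sketch of that arithmetic, as an illustration and not part of the commit:

def conv_out(size, kernel):
    # output side length of a valid (unpadded) convolution, stride 1
    return size - kernel + 1

side = 28
side = conv_out(side, 5) // 2  # conv1: 28 -> 24, pool: -> 12
side = conv_out(side, 5) // 2  # conv2: 12 -> 8,  pool: -> 4
dim = 32
assert dim * 2 * side * side == 1024  # matches nn.Linear(dim * 2 * 4 * 4, 10)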
src/models/components/lnn.py
DELETED
@@ -1,23 +0,0 @@
-import torch
-from torch import nn
-
-
-class LNN(nn.Module):
-    # A fully connected network for handwritten digit recognition; the dim parameter controls the hidden-layer width
-    def __init__(self, dim=32):
-        super(LNN, self).__init__()
-        self.fc1 = nn.Linear(28 * 28, dim)
-        self.fc2 = nn.Linear(dim, 10)
-
-    def forward(self, x):
-        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
-        x = torch.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
-
-
-if __name__ == "__main__":
-    input = torch.randn(2, 1, 28, 28)
-    model = LNN()
-    output = model(input)
-    assert output.shape == (2, 10)
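Editor's note: the deleted lnn.py flattens each 1 x 28 x 28 image to 784 features before the two linear layers. A rough parameter count under the default dim=32, as an illustration and not part of the commit:

fc1_params = 28 * 28 * 32 + 32  # weights + bias of nn.Linear(784, 32) = 25,120
fc2_params = 32 * 10 + 10       # weights + bias of nn.Linear(32, 10)  = 330
assert fc1_params + fc2_params == 25450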
src/models/components/unet.py
DELETED
@@ -1,63 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class UNet(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(UNet, self).__init__()
-
-        def conv_block(in_channels, out_channels):
-            return nn.Sequential(
-                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-                nn.ReLU(inplace=True)
-            )
-
-        self.encoder1 = conv_block(in_channels, 64)
-        self.encoder2 = conv_block(64, 128)
-        self.encoder3 = conv_block(128, 256)
-        self.encoder4 = conv_block(256, 512)
-        self.bottleneck = conv_block(512, 1024)
-
-        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
-        self.decoder4 = conv_block(1024, 512)
-        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
-        self.decoder3 = conv_block(512, 256)
-        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
-        self.decoder2 = conv_block(256, 128)
-        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
-        self.decoder1 = conv_block(128, 64)
-
-        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
-
-    def forward(self, x):
-        enc1 = self.encoder1(x)
-        enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2, stride=2))
-        enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2, stride=2))
-        enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2, stride=2))
-        bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2, stride=2))
-
-        dec4 = self.upconv4(bottleneck)
-        dec4 = torch.cat((dec4, enc4), dim=1)
-        dec4 = self.decoder4(dec4)
-        dec3 = self.upconv3(dec4)
-        dec3 = torch.cat((dec3, enc3), dim=1)
-        dec3 = self.decoder3(dec3)
-        dec2 = self.upconv2(dec3)
-        dec2 = torch.cat((dec2, enc2), dim=1)
-        dec2 = self.decoder2(dec2)
-        dec1 = self.upconv1(dec2)
-        dec1 = torch.cat((dec1, enc1), dim=1)
-        dec1 = self.decoder1(dec1)
-
-        return self.final(dec1)
-
-if __name__ == "__main__":
-    model = UNet(in_channels=3,out_channels=7)
-    fake_img = torch.rand(size=(2,3,224,224))
-    print(fake_img.shape)
-    # torch.Size([2, 3, 224, 224])
-    out = model(fake_img)
-    print(out.shape)
-    # torch.Size([2, 7, 224, 224])
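Editor's note: the deleted unet.py downsamples four times with 2 x 2 max-pools, so the skip concatenations only line up when the input height and width are divisible by 2**4 = 16; that is why the self-test uses 224 x 224. Each decoder block then takes twice its upconv's channels because the matching encoder output is concatenated (e.g. decoder4 sees 512 + 512 = 1024). A small check, as an illustration and not part of the commit:

for side in (224, 256, 512):
    # any of these sides keeps every pool/upconv pair shape-compatible
    assert side % 16 == 0, f"{side} would break the skip connections"
bottleneck_side = 224 // 16  # the 1024-channel bottleneck sees a 14 x 14 map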
src/models/components/vae.py
DELETED
@@ -1,152 +0,0 @@
-import torch
-from torch import nn
-import torch.nn.functional as F
-from contextlib import contextmanager
-from typing import List, Dict
-from src.plugin.taming_transformers.taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
-
-from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
-from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-
-import matplotlib.pyplot as plt
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AutoencoderKL(nn.Module):
-    def __init__(
-        self,
-        double_z: bool = True,
-        z_channels: int = 3,
-        resolution: int = 512,
-        in_channels: int = 3,
-        out_ch: int = 3,
-        ch: int = 128,
-        ch_mult: List = [1, 2, 4, 4],
-        num_res_blocks: int = 2,
-        attn_resolutions: List = [],
-        dropout: float = 0.0,
-        embed_dim: int = 3,
-        ckpt_path: str = None,
-        ignore_keys: List = [],
-    ):
-        super(AutoencoderKL, self).__init__()
-        ddconfig = {
-            "double_z": double_z,
-            "z_channels": z_channels,
-            "resolution": resolution,
-            "in_channels": in_channels,
-            "out_ch": out_ch,
-            "ch": ch,
-            "ch_mult": ch_mult,
-            "num_res_blocks": num_res_blocks,
-            "attn_resolutions": attn_resolutions,
-            "dropout": dropout
-        }
-        self.encoder = Encoder(**ddconfig)
-        self.decoder = Decoder(**ddconfig)
-        assert ddconfig["double_z"]
-        self.quant_conv = nn.Conv2d(
-            2 * ddconfig["z_channels"], 2 * embed_dim, 1)
-        self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
-        self.embed_dim = embed_dim
-        if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
-
-    def init_from_ckpt(self, path, ignore_keys=list()):
-        sd = torch.load(path, map_location="cpu")["state_dict"]
-        keys = list(sd.keys())
-        for k in keys:
-            for ik in ignore_keys:
-                if k.startswith(ik):
-                    print(f"Deleting key {k} from state_dict.")
-                    del sd[k]
-        self.load_state_dict(sd, strict=False)
-        print(f"Restored from {path}")
-
-    def encode(self, x):
-        h = self.encoder(x)  # B, C, h, w
-        moments = self.quant_conv(h)  # B, 6, h, w
-        posterior = DiagonalGaussianDistribution(moments)
-        return posterior  # the posterior distribution
-
-    def decode(self, z):
-        z = self.post_quant_conv(z)
-        dec = self.decoder(z)
-        return dec
-
-    def forward(self, input, sample_posterior=True):
-        posterior = self.encode(input)  # Gaussian posterior
-        if sample_posterior:
-            z = posterior.sample()  # sample from the posterior
-        else:
-            z = posterior.mode()
-        dec = self.decode(z)
-        last_layer_weight = self.decoder.conv_out.weight
-        return dec, posterior, last_layer_weight
-
-
-if __name__ == '__main__':
-    # Test the input and output shapes of the model
-    model = AutoencoderKL()
-    x = torch.randn(1, 3, 512, 512)
-    dec, posterior, last_layer_weight = model(x)
-
-    assert dec.shape == (1, 3, 512, 512)
-    assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
-    assert last_layer_weight.shape == (3, 128, 3, 3)
-
-    # Plot the latent space and the reconstruction from the pretrained model
-    model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
-    model.eval()
-    image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"
-
-    from PIL import Image
-    import numpy as np
-    from src.data.components.celeba import DalleTransformerPreprocessor
-    from src.data.components.celeba import CelebA
-    image = Image.open(image_path).convert('RGB')
-    image = np.array(image).astype(np.uint8)
-    import copy
-    original = copy.deepcopy(image)
-    transform = DalleTransformerPreprocessor(size=512, phase='test')
-    image = transform(image=image)['image']
-    image = image.astype(np.float32)/127.5 - 1.0
-    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
-
-    dec, posterior, last_layer_weight = model(image)
-
-    # original image
-    plt.subplot(1, 3, 1)
-    plt.imshow(original)
-    plt.title("Original")
-    plt.axis("off")
-
-    # sampled image from the latent space
-    plt.subplot(1, 3, 2)
-    x = model.decode(posterior.sample())
-    x = (x+1)/2
-    x = x.squeeze(0).permute(1, 2, 0).cpu()
-    x = x.detach().numpy()
-    x = x.clip(0, 1)
-    x = (x*255).astype(np.uint8)
-    plt.imshow(x)
-    plt.title("Sampled")
-    plt.axis("off")
-
-    # reconstructed image
-    plt.subplot(1, 3, 3)
-    x = dec
-    x = (x+1)/2
-    x = x.squeeze(0).permute(1, 2, 0).cpu()
-    x = x.detach().numpy()
-    x = x.clip(0, 1)
-    x = (x*255).astype(np.uint8)
-    plt.imshow(x)
-    plt.title("Reconstructed")
-    plt.axis("off")
-
-    plt.tight_layout()
-    plt.savefig("vae_reconstruction.png")
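Editor's note: the (1, 3, 64, 64) latent shape asserted in the deleted vae.py follows from its defaults, assuming the ldm Encoder halves the resolution once per level transition: ch_mult=[1, 2, 4, 4] gives three halvings, and quant_conv maps 2 * z_channels to 2 * embed_dim channels (mean and log-variance of the diagonal Gaussian). A sketch of that bookkeeping, as an illustration and not part of the commit:

ch_mult = [1, 2, 4, 4]
halvings = len(ch_mult) - 1         # 3 downsampling steps in the Encoder
latent_side = 512 // 2 ** halvings  # 512 / 8 = 64
moment_channels = 2 * 3             # quant_conv output: mean + log-variance for embed_dim=3
assert (latent_side, moment_channels) == (64, 6)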