caixiaoshun committed
Commit dca1999 · 1 Parent(s): 2c73586

Delete unused files

src/models/components/cnn.py DELETED
@@ -1,26 +0,0 @@
- import torch
- from torch import nn
-
-
- class CNN(nn.Module):
-     def __init__(self, dim=32):
-         super(CNN, self).__init__()
-         self.conv1 = nn.Conv2d(1, dim, 5)
-         self.conv2 = nn.Conv2d(dim, dim * 2, 5)
-         self.fc1 = nn.Linear(dim * 2 * 4 * 4, 10)
-
-     def forward(self, x):
-         x = torch.relu(self.conv1(x))
-         x = torch.max_pool2d(x, 2)
-         x = torch.relu(self.conv2(x))
-         x = torch.max_pool2d(x, 2)
-         x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
-         x = self.fc1(x)
-         return x
-
-
- if __name__ == "__main__":
-     input = torch.randn(2, 1, 28, 28)
-     model = CNN()
-     output = model(input)
-     assert output.shape == (2, 10)
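Note (editor's sketch, not part of the commit): the hard-coded nn.Linear(dim * 2 * 4 * 4, 10) only matches 28×28 inputs. A minimal check of the shape arithmetic, assuming the layers above:

    # conv1 (5x5, no padding): 28 -> 24; 2x2 max-pool: 24 -> 12
    # conv2 (5x5, no padding): 12 -> 8;  2x2 max-pool: 8  -> 4
    side = 28
    side = (side - 5 + 1) // 2  # after conv1 + pool: 12
    side = (side - 5 + 1) // 2  # after conv2 + pool: 4
    assert side == 4  # hence fc1 takes (dim * 2) * 4 * 4 input features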
src/models/components/lnn.py DELETED
@@ -1,23 +0,0 @@
- import torch
- from torch import nn
-
-
- class LNN(nn.Module):
-     # A fully connected network for handwritten digit recognition;
-     # the dim argument controls the hidden-layer width.
-     def __init__(self, dim=32):
-         super(LNN, self).__init__()
-         self.fc1 = nn.Linear(28 * 28, dim)
-         self.fc2 = nn.Linear(dim, 10)
-
-     def forward(self, x):
-         x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
-         x = torch.relu(self.fc1(x))
-         x = self.fc2(x)
-         return x
-
-
- if __name__ == "__main__":
-     input = torch.randn(2, 1, 28, 28)
-     model = LNN()
-     output = model(input)
-     assert output.shape == (2, 10)
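Note (editor's sketch, not part of the commit): the view call assumes a 4-D (N, C, H, W) batch whose per-sample size equals 28 * 28; x.flatten(1) expresses the same thing. A minimal check:

    import torch
    x = torch.randn(2, 1, 28, 28)
    flat = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
    assert flat.shape == (2, 784)
    assert torch.equal(flat, x.flatten(1))  # equivalent flattening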
src/models/components/unet.py DELETED
@@ -1,63 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class UNet(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(UNet, self).__init__()
-
-         def conv_block(in_channels, out_channels):
-             return nn.Sequential(
-                 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-                 nn.ReLU(inplace=True),
-                 nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-                 nn.ReLU(inplace=True)
-             )
-
-         self.encoder1 = conv_block(in_channels, 64)
-         self.encoder2 = conv_block(64, 128)
-         self.encoder3 = conv_block(128, 256)
-         self.encoder4 = conv_block(256, 512)
-         self.bottleneck = conv_block(512, 1024)
-
-         self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
-         self.decoder4 = conv_block(1024, 512)
-         self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
-         self.decoder3 = conv_block(512, 256)
-         self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
-         self.decoder2 = conv_block(256, 128)
-         self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
-         self.decoder1 = conv_block(128, 64)
-
-         self.final = nn.Conv2d(64, out_channels, kernel_size=1)
-
-     def forward(self, x):
-         enc1 = self.encoder1(x)
-         enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2, stride=2))
-         enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2, stride=2))
-         enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2, stride=2))
-         bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2, stride=2))
-
-         dec4 = self.upconv4(bottleneck)
-         dec4 = torch.cat((dec4, enc4), dim=1)
-         dec4 = self.decoder4(dec4)
-         dec3 = self.upconv3(dec4)
-         dec3 = torch.cat((dec3, enc3), dim=1)
-         dec3 = self.decoder3(dec3)
-         dec2 = self.upconv2(dec3)
-         dec2 = torch.cat((dec2, enc2), dim=1)
-         dec2 = self.decoder2(dec2)
-         dec1 = self.upconv1(dec2)
-         dec1 = torch.cat((dec1, enc1), dim=1)
-         dec1 = self.decoder1(dec1)
-
-         return self.final(dec1)
-
- if __name__ == "__main__":
-     model = UNet(in_channels=3, out_channels=7)
-     fake_img = torch.rand(size=(2, 3, 224, 224))
-     print(fake_img.shape)
-     # torch.Size([2, 3, 224, 224])
-     out = model(fake_img)
-     print(out.shape)
-     # torch.Size([2, 7, 224, 224])
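Note (editor's sketch, not part of the commit): the four 2x2 poolings shrink H and W by a factor of 16, and each torch.cat needs the upsampled decoder map to match the stored encoder map exactly, so inputs must have H and W divisible by 16 (224 = 16 * 14 works):

    for size in (224, 256, 512):
        assert size % 16 == 0  # valid input sides for this UNet
    assert 230 % 16 != 0       # e.g. 230x230 would break the skip concatenations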
src/models/components/vae.py DELETED
@@ -1,152 +0,0 @@
- import torch
- from torch import nn
- import torch.nn.functional as F
- from contextlib import contextmanager
- from typing import List, Dict
- from src.plugin.taming_transformers.taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
-
- from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
- from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-
- import matplotlib.pyplot as plt
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class AutoencoderKL(nn.Module):
-     def __init__(
-         self,
-         double_z: bool = True,
-         z_channels: int = 3,
-         resolution: int = 512,
-         in_channels: int = 3,
-         out_ch: int = 3,
-         ch: int = 128,
-         ch_mult: List = [1, 2, 4, 4],
-         num_res_blocks: int = 2,
-         attn_resolutions: List = [],
-         dropout: float = 0.0,
-         embed_dim: int = 3,
-         ckpt_path: str = None,
-         ignore_keys: List = [],
-     ):
-         super(AutoencoderKL, self).__init__()
-         ddconfig = {
-             "double_z": double_z,
-             "z_channels": z_channels,
-             "resolution": resolution,
-             "in_channels": in_channels,
-             "out_ch": out_ch,
-             "ch": ch,
-             "ch_mult": ch_mult,
-             "num_res_blocks": num_res_blocks,
-             "attn_resolutions": attn_resolutions,
-             "dropout": dropout
-         }
-         self.encoder = Encoder(**ddconfig)
-         self.decoder = Decoder(**ddconfig)
-         assert ddconfig["double_z"]
-         self.quant_conv = nn.Conv2d(
-             2 * ddconfig["z_channels"], 2 * embed_dim, 1)
-         self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
-         self.embed_dim = embed_dim
-         if ckpt_path is not None:
-             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
-
-     def init_from_ckpt(self, path, ignore_keys=list()):
-         sd = torch.load(path, map_location="cpu")["state_dict"]
-         keys = list(sd.keys())
-         for k in keys:
-             for ik in ignore_keys:
-                 if k.startswith(ik):
-                     print(f"Deleting key {k} from state_dict.")
-                     del sd[k]
-         self.load_state_dict(sd, strict=False)
-         print(f"Restored from {path}")
-
-     def encode(self, x):
-         h = self.encoder(x)  # B, C, h, w
-         moments = self.quant_conv(h)  # B, 6, h, w
-         posterior = DiagonalGaussianDistribution(moments)
-         return posterior  # the posterior distribution
-
-     def decode(self, z):
-         z = self.post_quant_conv(z)
-         dec = self.decoder(z)
-         return dec
-
-     def forward(self, input, sample_posterior=True):
-         posterior = self.encode(input)  # Gaussian posterior
-         if sample_posterior:
-             z = posterior.sample()  # sample from the posterior
-         else:
-             z = posterior.mode()
-         dec = self.decode(z)
-         last_layer_weight = self.decoder.conv_out.weight
-         return dec, posterior, last_layer_weight
-
-
- if __name__ == '__main__':
-     # Test the input and output shapes of the model
-     model = AutoencoderKL()
-     x = torch.randn(1, 3, 512, 512)
-     dec, posterior, last_layer_weight = model(x)
-
-     assert dec.shape == (1, 3, 512, 512)
-     assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
-     assert last_layer_weight.shape == (3, 128, 3, 3)
-
-     # Plot the latent space and the reconstruction from the pretrained model
-     model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
-     model.eval()
-     image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"
-
-     from PIL import Image
-     import numpy as np
-     from src.data.components.celeba import DalleTransformerPreprocessor
-     from src.data.components.celeba import CelebA
-     image = Image.open(image_path).convert('RGB')
-     image = np.array(image).astype(np.uint8)
-     import copy
-     original = copy.deepcopy(image)
-     transform = DalleTransformerPreprocessor(size=512, phase='test')
-     image = transform(image=image)['image']
-     image = image.astype(np.float32) / 127.5 - 1.0
-     image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
-
-     dec, posterior, last_layer_weight = model(image)
-
-     # original image
-     plt.subplot(1, 3, 1)
-     plt.imshow(original)
-     plt.title("Original")
-     plt.axis("off")
-
-     # sampled image from the latent space
-     plt.subplot(1, 3, 2)
-     x = model.decode(posterior.sample())
-     x = (x + 1) / 2
-     x = x.squeeze(0).permute(1, 2, 0).cpu()
-     x = x.detach().numpy()
-     x = x.clip(0, 1)
-     x = (x * 255).astype(np.uint8)
-     plt.imshow(x)
-     plt.title("Sampled")
-     plt.axis("off")
-
-     # reconstructed image
-     plt.subplot(1, 3, 3)
-     x = dec
-     x = (x + 1) / 2
-     x = x.squeeze(0).permute(1, 2, 0).cpu()
-     x = x.detach().numpy()
-     x = x.clip(0, 1)
-     x = (x * 255).astype(np.uint8)
-     plt.imshow(x)
-     plt.title("Reconstructed")
-     plt.axis("off")
-
-     plt.tight_layout()
-     plt.savefig("vae_reconstruction.png")
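Note (editor's sketch, not part of the commit): with the default ch_mult = [1, 2, 4, 4] the LDM-style Encoder has four resolution levels and downsamples by 2 ** (4 - 1) = 8, which is consistent with the assert above that a 512x512 input yields a 3x64x64 latent:

    resolution, num_levels = 512, 4  # defaults of AutoencoderKL above
    latent_side = resolution // 2 ** (num_levels - 1)
    assert latent_side == 64  # matches posterior.sample().shape == (1, 3, 64, 64)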