BioMike committed (verified)
Commit 2c5ec83 · 1 Parent(s): 2c480a0

Update model.py

Files changed (1)
  1. model.py +80 -87
model.py CHANGED
@@ -1,87 +1,80 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from transformers import PreTrainedModel, PretrainedConfig
-
-
- class BaseVAE(nn.Module):
-     def __init__(self, latent_dim=16):
-         super(BaseVAE, self).__init__()
-         self.latent_dim = latent_dim
-
-         self.encoder = nn.Sequential(
-             nn.Conv2d(3, 32, 4, 2, 1),    # 32x32 -> 16x16
-             nn.BatchNorm2d(32),
-             nn.ReLU(),
-             nn.Conv2d(32, 64, 4, 2, 1),   # 16x16 -> 8x8
-             nn.BatchNorm2d(64),
-             nn.ReLU(),
-             nn.Conv2d(64, 128, 4, 2, 1),  # 8x8 -> 4x4
-             nn.BatchNorm2d(128),
-             nn.ReLU(),
-             nn.Flatten()
-         )
-         self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)
-         self.fc_logvar = nn.Linear(128 * 4 * 4, latent_dim)
-
-         self.decoder_input = nn.Linear(latent_dim, 128 * 4 * 4)
-         self.decoder = nn.Sequential(
-             nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 4x4 -> 8x8
-             nn.BatchNorm2d(64),
-             nn.ReLU(),
-             nn.ConvTranspose2d(64, 32, 4, 2, 1),   # 8x8 -> 16x16
-             nn.BatchNorm2d(32),
-             nn.ReLU(),
-             nn.ConvTranspose2d(32, 3, 4, 2, 1),    # 16x16 -> 32x32
-             nn.Sigmoid()
-         )
-
-     def encode(self, x):
-         x = self.encoder(x)
-         mu = self.fc_mu(x)
-         logvar = self.fc_logvar(x)
-         return mu, logvar
-
-     def reparameterize(self, mu, logvar):
-         std = torch.exp(0.5 * logvar)
-         eps = torch.randn_like(std)
-         return mu + eps * std
-
-     def decode(self, z):
-         x = self.decoder_input(z)
-         x = x.view(-1, 128, 4, 4)
-         return self.decoder(x)
-
-     def forward(self, x):
-         mu, logvar = self.encode(x)
-         z = self.reparameterize(mu, logvar)
-         recon = self.decode(z)
-         return recon, mu, logvar
-
- class VAEConfig(PretrainedConfig):
-     model_type = "vae"
-
-     def __init__(self, latent_dim=16, **kwargs):
-         super().__init__(**kwargs)
-         self.latent_dim = latent_dim
-
- class VAEModel(PreTrainedModel):
-     config_class = VAEConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.vae = BaseVAE(latent_dim=config.latent_dim)
-         self.post_init()
-
-     def forward(self, x):
-         return self.vae(x)
-
-     def encode(self, x):
-         return self.vae.encode(x)
-
-     def decode(self, z):
-         return self.vae.decode(z)
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = VAEModel.from_pretrained("BioMike/emoji-vae-init").to(device)
- model.eval()
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class BaseVAE(nn.Module):
+     def __init__(self, latent_dim=16):
+         super().__init__()
+         self.latent_dim = latent_dim
+         input_dim = 3 * 32 * 32
+
+         self.encoder = nn.Sequential(
+             nn.Linear(input_dim, 1024),
+             nn.ReLU(),
+             nn.Linear(1024, 512),
+             nn.ReLU(),
+         )
+         self.fc_mu = nn.Linear(512, latent_dim)
+         self.fc_logvar = nn.Linear(512, latent_dim)
+
+         self.decoder_input = nn.Linear(latent_dim, 512)
+         self.decoder = nn.Sequential(
+             nn.ReLU(),
+             nn.Linear(512, 1024),
+             nn.ReLU(),
+             nn.Linear(1024, input_dim),
+             nn.Sigmoid()
+         )
+
+     def encode(self, x):
+         x = x.view(x.size(0), -1)
+         x = self.encoder(x)
+         mu = self.fc_mu(x)
+         logvar = self.fc_logvar(x)
+         return mu, logvar
+
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def decode(self, z):
+         x = self.decoder_input(z)
+         x = self.decoder(x)
+         x = x.view(-1, 3, 32, 32)
+         return x
+
+     def forward(self, x):
+         mu, logvar = self.encode(x)
+         z = self.reparameterize(mu, logvar)
+         recon = self.decode(z)
+         return recon, mu, logvar
+
+
+ class VAEConfig(PretrainedConfig):
+     model_type = "vae"
+
+     def __init__(self, latent_dim=16, **kwargs):
+         super().__init__(**kwargs)
+         self.latent_dim = latent_dim
+
+ class VAEModel(PreTrainedModel):
+     config_class = VAEConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.vae = BaseVAE(latent_dim=config.latent_dim)
+         self.post_init()
+
+     def forward(self, x):
+         return self.vae(x)
+
+     def encode(self, x):
+         return self.vae.encode(x)
+
+     def decode(self, z):
+         return self.vae.decode(z)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = VAEModel.from_pretrained("BioMike/emoji-vae-init").to(device)
+ model.eval()
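
Note: the commit replaces the convolutional encoder/decoder with fully connected layers over flattened 3x32x32 inputs, but it also removes the "from transformers import PreTrainedModel, PretrainedConfig" line even though both classes are still used further down, so that import has to be restored before the file will load. Below is a minimal usage sketch, not part of the commit: the local "from model import VAEModel" import, the dummy batch, and the loss computation are illustrative assumptions, while the "BioMike/emoji-vae-init" checkpoint name comes from the file itself.

# Minimal usage sketch (not part of the commit). It assumes the
# "from transformers import PreTrainedModel, PretrainedConfig" import
# has been restored in model.py and that model.py is importable locally.
import torch
import torch.nn.functional as F

from model import VAEModel  # hypothetical local import of the file above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAEModel.from_pretrained("BioMike/emoji-vae-init").to(device)
model.eval()

# Dummy batch of four 32x32 RGB images in [0, 1]; encode() flattens them
# to vectors of length input_dim = 3 * 32 * 32 before the Linear layers.
x = torch.rand(4, 3, 32, 32, device=device)

with torch.no_grad():
    recon, mu, logvar = model(x)  # recon has shape (4, 3, 32, 32)

# Standard VAE objective, shown only to illustrate how mu and logvar are
# typically consumed: reconstruction term plus KL divergence to N(0, I).
recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = (recon_loss + kl) / x.size(0)
print(recon.shape, loss.item())

The Sigmoid output keeps reconstructions in [0, 1], which is why binary cross-entropy is a natural reconstruction term here; the commit itself contains no training loop.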