AlekseyKorshuk committed on
Commit
6deedc6
·
1 Parent(s): 3927aba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -12,26 +12,35 @@ from torchvision.utils import save_image
12
 
13
 
14
class Generator(nn.Module):
    """Transposed-convolution (DCGAN-style) generator.

    Maps a latent batch of shape (N, nz, 1, 1) to an image batch of shape
    (N, nc, 24, 24), with outputs in [-1, 1] via the final Tanh.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage.
        # Built as a flat list so the Sequential module indices (0..10) match the
        # original layout exactly — required for pretrained state-dict loading.
        stages = (
            (nz, ngf * 4, 3, 1, 0),
            (ngf * 4, ngf * 2, 3, 2, 1),
            (ngf * 2, ngf, 4, 2, 0),
        )
        layers = []
        for cin, cout, kernel, stride, pad in stages:
            layers.append(nn.ConvTranspose2d(cin, cout, kernel, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.ReLU(True))
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch through the stack and return the image batch."""
        return self.model(input)
34
 
 
35
 
36
  # Instantiate the generator with default sizes; random init is overwritten
  # by the pretrained checkpoint fetched from the Hugging Face Hub below.
  model = Generator()
37
  weights_path = hf_hub_download('huggingnft/dooggies', 'pytorch_model.bin')
 
12
 
13
 
14
class Generator(nn.Module):
    """DCGAN generator: latent noise -> num_channels x 64 x 64 images in [-1, 1].

    The network is a stack of transposed convolutions, each doubling the
    spatial size (after the initial 1x1 -> 4x4 projection), finished by Tanh.
    """

    def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
        super(Generator, self).__init__()
        # (in_channels, out_channels) per ConvTranspose2d stage; all but the
        # first use kernel 4, stride 2, padding 1 (spatial size doubles).
        # Flattened into one Sequential so module indices (0..13) match the
        # original layout exactly — required for pretrained state-dict loading.
        up_stages = (
            # input is Z: 1x1 -> 4x4
            (latent_dim, hidden_size * 8, 4, 1, 0),
            # (hidden_size*8) x 4 x 4 -> (hidden_size*4) x 8 x 8
            (hidden_size * 8, hidden_size * 4, 4, 2, 1),
            # (hidden_size*4) x 8 x 8 -> (hidden_size*2) x 16 x 16
            (hidden_size * 4, hidden_size * 2, 4, 2, 1),
            # (hidden_size*2) x 16 x 16 -> (hidden_size) x 32 x 32
            (hidden_size * 2, hidden_size, 4, 2, 1),
        )
        layers = []
        for cin, cout, kernel, stride, pad in up_stages:
            layers.append(nn.ConvTranspose2d(cin, cout, kernel, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.ReLU(True))
        # Final stage: (hidden_size) x 32 x 32 -> (num_channels) x 64 x 64
        layers.append(nn.ConvTranspose2d(hidden_size, num_channels, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, noise):
        """Map noise of shape (N, latent_dim, 1, 1) to pixel values."""
        return self.model(noise)
44
 
45
  # Instantiate the generator with default sizes; random init is overwritten
  # by the pretrained checkpoint fetched from the Hugging Face Hub below.
  model = Generator()
46
  weights_path = hf_hub_download('huggingnft/dooggies', 'pytorch_model.bin')