Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
class DCGAN3D_G(nn.Module):
    """DCGAN-style 3-D generator built from transposed convolutions.

    The network is an ``nn.Sequential`` stored as ``self.main``:

    * a stride-1 ``ConvTranspose3d`` (kernel 4, no padding) from ``nz``
      channels to the widest feature width, then BatchNorm + ReLU;
    * repeated stride-2 ``ConvTranspose3d`` (kernel 4, padding 1) +
      BatchNorm + ReLU stages, each doubling every spatial extent and
      halving the channel count;
    * ``n_extra_layers`` size-preserving ``Conv3d`` (kernel 3, padding 1) +
      BatchNorm + ReLU stages;
    * a final stride-2 ``ConvTranspose3d`` to ``nc`` channels and ``Tanh``.

    By the transposed-convolution size formula, a latent input of shape
    ``(N, nz, 1, 1, 1)`` becomes a 4x4x4 volume after the first layer, and
    the output has shape ``(N, nc, isize, isize, isize)`` with values in
    ``[-1, 1]``.

    Submodules are registered under the names ``"0"``, ``"1"``, ... in
    order, so ``state_dict`` keys have the form ``main.<index>.<param>``.

    Args:
        isize: Edge length of the generated cube. It must be a multiple of
            16 and a power of two (i.e. a power of two >= 16).
        nz: Number of channels of the latent input.
        nc: Number of output channels.
        ngf: Base feature width. The stage feeding the output layer has
            ``2 * (ngf // 2)`` channels, doubling at each earlier stage.
        ngpu: Stored as ``self.ngpu``; not used by ``forward`` here.
        n_extra_layers: Number of extra size-preserving conv stages.

    Raises:
        ValueError: If ``isize`` is not a multiple of 16 or is not a power
            of two.
    """

    def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        super().__init__()
        self.ngpu = ngpu
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if isize % 16 != 0:
            raise ValueError("isize has to be a multiple of 16")

        # Widest channel count: double ngf // 2 once per doubling of the
        # spatial size from 4 up to isize.
        cngf, tisize = ngf // 2, 4
        while tisize < isize:
            cngf *= 2
            tisize *= 2
        # Stopping at ``>=`` and comparing afterwards rejects sizes that
        # doubling from 4 can never reach (e.g. 48, or isize <= 0), which an
        # equality-driven loop would spin on forever.
        if tisize != isize:
            raise ValueError("isize has to be a power of two")

        main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose3d(nz, cngf, 4, 1, 0, bias=False),
            nn.BatchNorm3d(cngf),
            nn.ReLU(True),
        )

        # Names continue the Sequential's own numbering ("0".."2" above).
        i, csize = 3, 4
        # Upsampling stages: 4 -> isize // 2, halving the channels each time.
        while csize < isize // 2:
            i = self._append_layers(
                main, i,
                nn.ConvTranspose3d(cngf, cngf // 2, 4, 2, 1, bias=False),
                nn.BatchNorm3d(cngf // 2),
                nn.ReLU(True),
            )
            cngf //= 2
            csize *= 2

        # Extra layers (spatial size and channel count unchanged)
        for _ in range(n_extra_layers):
            i = self._append_layers(
                main, i,
                nn.Conv3d(cngf, cngf, 3, 1, 1, bias=False),
                nn.BatchNorm3d(cngf),
                nn.ReLU(True),
            )

        # Final upsampling isize // 2 -> isize, squashed into [-1, 1].
        self._append_layers(
            main, i,
            nn.ConvTranspose3d(cngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )
        self.main = main

    @staticmethod
    def _append_layers(main, start, *layers):
        """Register ``layers`` on ``main`` as "start", "start+1", ...; return the next free index."""
        for offset, layer in enumerate(layers):
            main.add_module(str(start + offset), layer)
        return start + len(layers)

    def forward(self, input):
        """Run the generator on ``input`` and return the generated volume."""
        return self.main(input)