Spaces:
Configuration error
Configuration error
import os | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
import numpy as np | |
# class Discriminator(nn.Module): | |
# def __init__(self, ngpu, nc = 3, ndf = 64): | |
# super(Discriminator, self).__init__() | |
# self.ngpu = ngpu | |
# self.main = nn.Sequential( | |
# # input is (nc) x 64 x 64 | |
# nn.Conv2d(nc, ndf, 4, 4, 1, bias=False), | |
# nn.LeakyReLU(0.2, inplace=True), | |
# # state size. (ndf) x 32 x 32 | |
# nn.Conv2d(ndf, ndf * 2, 4, 4, 1, bias=False), | |
# nn.BatchNorm2d(ndf * 2), | |
# nn.LeakyReLU(0.2, inplace=True), | |
# # state size. (ndf*2) x 16 x 16 | |
# nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), | |
# nn.BatchNorm2d(ndf * 4), | |
# nn.LeakyReLU(0.2, inplace=True), | |
# # state size. (ndf*4) x 8 x 8 | |
# nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), | |
# nn.BatchNorm2d(ndf * 8), | |
# nn.LeakyReLU(0.2, inplace=True), | |
# # state size. (ndf*8) x 4 x 4 | |
# nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), | |
# nn.Sigmoid() | |
# ) | |
# def forward(self, input): | |
# return self.main(input) | |
class Discriminator(torch.nn.Module):
    """Convolutional critic for WGAN with gradient penalty (WGAN-GP).

    Three stride-2 convolutions (256, 512, 1024 filters), each followed by
    instance normalization and LeakyReLU, feed a final 4x4 convolution that
    produces one unbounded score per 4x4 window (no sigmoid is applied).

    The per-layer shapes noted in the comments below assume a 32x32 input,
    for which the main module yields 1024x4x4 and the score is 1x1x1.

    Args:
        channels: Number of channels in the input images.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Filters [256, 512, 1024]
        # Input_dim = channels (Cx32x32)
        # Output_dim = 1
        self.main_module = nn.Sequential(
            # Batch normalization is omitted in the critic: the WGAN-GP
            # objective penalizes the norm of the critic's gradient with
            # respect to each input independently, not the entire batch, so
            # batch-level statistics would invalidate it. Per-instance
            # normalization (nn.InstanceNorm2d) is used in place of layer
            # normalization.

            # Image (Cx32x32)
            nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (256x16x16)
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (512x8x8)
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(1024, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Output of main module --> State (1024x4x4)

        self.output = nn.Sequential(
            # The output of D is no longer a probability, so no sigmoid is
            # applied at the output of D.
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw critic score map for a batch of images.

        Args:
            x: Batch of images with ``channels`` channels.

        Returns:
            The output of the final convolution applied to the main module's
            features (one value per sample for 32x32 inputs).
        """
        x = self.main_module(x)
        return self.output(x)

    def feature_extraction(self, x: torch.Tensor) -> torch.Tensor:
        """Return the main module's features flattened to one row per sample.

        For 32x32 inputs each row has 1024 * 4 * 4 = 16384 entries.

        Args:
            x: Batch of images with ``channels`` channels.

        Returns:
            A 2-D tensor whose first dimension is the batch size.
        """
        x = self.main_module(x)
        # Flatten everything except the batch dimension. A hard-coded
        # view(-1, 1024*4*4) would silently fold extra spatial positions into
        # the batch dimension for inputs that are not 32x32, and .view can
        # fail on non-contiguous activations.
        return torch.flatten(x, 1)