# T2I / gan_cls_768.py
# DataRaptor's picture
# Upload 6 files
# f8a1225
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import pdb
from torch.nn import functional as F
from torch.nn import init
'''
'''
class Concat_embed4(nn.Module):
    """Project an embedding and concatenate it onto a feature map as channels.

    The embedding goes through a three-layer MLP (``embed_dim -> embed_dim ->
    embed_dim -> projected_embed_dim``), is replicated over every spatial
    position of the feature map, and is then concatenated along the channel
    axis.

    The original implementation replicated onto a fixed 4x4 grid; this version
    uses the spatial size of ``inp`` instead, so 4x4 inputs behave exactly as
    before and other sizes no longer fail at the concatenation.

    Args:
        embed_dim: Size of the incoming embedding vector.
        projected_embed_dim: Number of channels the embedding is reduced to.
    """

    def __init__(self, embed_dim: int, projected_embed_dim: int):
        super().__init__()
        # Layer order and the attribute name are kept as-is so existing
        # state_dict keys ("projection.0.weight", ...) stay valid.
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
        """Return ``inp`` with the projected embedding appended as channels.

        Args:
            inp: Feature map of shape ``(batch, channels, height, width)``.
            embed: Embedding of shape ``(batch, embed_dim)``.

        Returns:
            Tensor of shape
            ``(batch, channels + projected_embed_dim, height, width)``.
        """
        projected_embed = self.projection(embed)
        height, width = inp.shape[2], inp.shape[3]
        # (batch, P) -> (batch, P, 1, 1) -> broadcast over the spatial grid.
        # expand() creates a view; torch.cat materialises the result.
        replicated_embed = projected_embed[:, :, None, None].expand(
            -1, -1, height, width
        )
        return torch.cat([inp, replicated_embed], 1)
class generator(nn.Module):
    """Text-conditioned DCGAN-style generator.

    The embedding is reduced to ``projected_embed_dim`` values by an MLP,
    concatenated with the noise vector along the channel axis, and the
    resulting ``latent_dim x 1 x 1`` tensor is upsampled by five transposed
    convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels) ending in ``Tanh``.

    All sizes were previously hard-coded; they are now keyword-only
    constructor arguments whose defaults reproduce the original network, so
    ``generator()`` builds the same modules with the same state_dict keys.

    Args:
        noise_dim: Number of channels of the noise input ``z``.
        embed_dim: Size of the conditioning embedding vector.
        projected_embed_dim: Size the embedding is projected down to.
        ngf: Base number of generator feature maps.
        num_channels: Number of channels of the generated image.
    """

    def __init__(
        self,
        *,
        noise_dim: int = 100,
        embed_dim: int = 768,
        projected_embed_dim: int = 128,
        ngf: int = 64,
        num_channels: int = 3,
    ):
        super().__init__()
        # Fixed by the five-layer architecture below (1x1 latent -> 64x64),
        # so it is recorded here rather than exposed as a parameter.
        self.image_size = 64
        self.num_channels = num_channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.projected_embed_dim = projected_embed_dim
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = ngf

        self.projection = nn.Sequential(
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim,
                      out_features=self.projected_embed_dim),
            nn.BatchNorm1d(num_features=self.projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        # Kept as a ModuleList (applied sequentially in forward) so the
        # attribute type and parameter names match the original.
        self.netG = nn.ModuleList([
            nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (num_channels) x 64 x 64
        ])

    def forward(self, embed_vector: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Generate images from an embedding and a noise tensor.

        Args:
            embed_vector: Embedding of shape ``(batch, embed_dim)``.
            z: Noise of shape ``(batch, noise_dim, 1, 1)``. A 2-D tensor of
                shape ``(batch, noise_dim)`` is also accepted and is given
                the two trailing singleton dimensions automatically.

        Returns:
            Tensor of shape ``(batch, num_channels, 64, 64)`` with values in
            ``[-1, 1]`` (output of ``Tanh``).
        """
        projected_embed = self.projection(embed_vector)
        if z.dim() == 2:
            # Flat noise would otherwise fail in torch.cat against the 4-D
            # projected embedding.
            z = z.unsqueeze(2).unsqueeze(3)
        out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
        for layer in self.netG:
            out = layer(out)
        return out
class discriminator(nn.Module):
    """Text-conditioned DCGAN-style discriminator.

    ``netD_1`` downsamples a ``num_channels x 64 x 64`` image with four
    stride-2 convolutions to an ``(ndf*8) x 4 x 4`` feature map. ``projector``
    appends the projected embedding as extra channels, and ``netD_2`` reduces
    the fused map to a single sigmoid score per sample.
    """

    def __init__(self):
        super().__init__()
        self.image_size = 64
        self.num_channels = 3
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.ndf = 64
        # B_dim / C_dim are stored but not used anywhere in this class.
        self.B_dim = 128
        self.C_dim = 16

        # Channel widths after each of the four downsampling convolutions.
        widths = [self.ndf * factor for factor in (1, 2, 4, 8)]

        # input is (nc) x 64 x 64; the first stage has no batch norm.
        image_layers = [
            nn.Conv2d(self.num_channels, widths[0],
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        for in_channels, out_channels in zip(widths[:-1], widths[1:]):
            image_layers.append(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=4, stride=2, padding=1, bias=False)
            )
            image_layers.append(nn.BatchNorm2d(out_channels))
            image_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.netD_1 = nn.Sequential(*image_layers)

        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        top_width = widths[-1]
        fused_width = top_width + self.projected_embed_dim
        score_layers = [
            # 1x1 convolution mixes image features with the embedding channels.
            nn.Conv2d(fused_width, top_width,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(top_width),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 convolution collapses the 4 x 4 map to one value per sample.
            nn.Conv2d(top_width, 1,
                      kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.netD_2 = nn.Sequential(*score_layers)

    def forward(self, inp, embed):
        """Score images against their embeddings.

        Args:
            inp: Image batch fed to ``netD_1``.
            embed: Embedding batch passed to ``projector``.

        Returns:
            A tuple ``(scores, features)``: ``scores`` is the sigmoid output
            of ``netD_2`` flattened to one dimension, and ``features`` is the
            output of ``netD_1`` before the embedding is concatenated.
        """
        image_features = self.netD_1(inp)
        fused = self.projector(image_features, embed)
        scores = self.netD_2(fused)
        flat_scores = scores.view(-1, 1).squeeze(1)
        return flat_scores, image_features