'''
Text-conditioned DCGAN for 64x64 image synthesis: a generator and a
discriminator, both conditioned on a 768-dimensional text embedding, plus a
helper module (Concat_embed4) that projects the embedding and fuses it with
the discriminator's spatial features.
'''
import torch
import torch.nn as nn

class Concat_embed4(nn.Module):

    def __init__(self, embed_dim, projected_embed_dim):
        super(Concat_embed4, self).__init__()
        # MLP that maps the raw text embedding down to projected_embed_dim.
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp, embed):
        # Project the text embedding: (B, embed_dim) -> (B, projected_embed_dim).
        projected_embed = self.projection(embed)
        # Tile it over the 4x4 spatial grid: repeat -> (4, 4, B, projected_embed_dim),
        # permute -> (B, projected_embed_dim, 4, 4).
        replicated_embed = projected_embed.repeat(4, 4, 1, 1).permute(2, 3, 0, 1)
        # Concatenate with the image feature map along the channel dimension.
        hidden_concat = torch.cat([inp, replicated_embed], 1)
        return hidden_concat
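

# Shape example (an added sketch, not part of the original file): with batch
# size 8, Concat_embed4(768, 128) fuses an (8, 768) embedding with an
# (8, 512, 4, 4) feature map into an (8, 640, 4, 4) tensor:
#
#     layer = Concat_embed4(embed_dim=768, projected_embed_dim=128)
#     fused = layer(torch.randn(8, 512, 4, 4), torch.randn(8, 768))
#     assert fused.shape == (8, 512 + 128, 4, 4)
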
class generator(nn.Module):

    def __init__(self):
        super(generator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.noise_dim = 100
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = 64

        # Project the text embedding (embed_dim) down to projected_embed_dim
        # before concatenating it with the noise vector.
        self.projection = nn.Sequential(
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.projected_embed_dim),
            nn.BatchNorm1d(num_features=self.projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

        # DCGAN-style upsampling stack: (latent_dim) x 1 x 1 -> (num_channels) x 64 x 64.
        self.netG = nn.ModuleList([
            nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (num_channels) x 64 x 64
        ])

    def forward(self, embed_vector, z):
        # embed_vector: (B, embed_dim); z: (B, noise_dim, 1, 1).
        projected_embed = self.projection(embed_vector)
        # (B, projected_embed_dim) -> (B, projected_embed_dim, 1, 1), then concatenate
        # with the noise to form the (B, latent_dim, 1, 1) generator input.
        out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
        for m in self.netG:
            out = m(out)
        return out

class discriminator(nn.Module):

    def __init__(self):
        super(discriminator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.ndf = 64
        self.B_dim = 128
        self.C_dim = 16

        # Convolutional feature extractor: (nc) x 64 x 64 -> (ndf*8) x 4 x 4.
        self.netD_1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(self.num_channels, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            # SelfAttention(self.ndf),
            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
        )

        # Projects the text embedding and tiles it over the 4x4 feature map.
        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        # Classifier head: fuses image and text features, then scores real vs. fake.
        self.netD_2 = nn.Sequential(
            # state size. (ndf*8 + projected_embed_dim) x 4 x 4
            nn.Conv2d(self.ndf * 8 + self.projected_embed_dim,
                      self.ndf * 8, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inp, embed):
        # inp: (B, num_channels, 64, 64); embed: (B, embed_dim).
        x_intermediate = self.netD_1(inp)
        x = self.projector(x_intermediate, embed)
        x = self.netD_2(x)
        # Return per-image scores and the intermediate features
        # (the latter are useful for feature-matching style losses).
        return x.view(-1, 1).squeeze(1), x_intermediate
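

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original Space):
    # run both networks on random tensors and check the expected output shapes.
    batch_size = 4
    embed = torch.randn(batch_size, 768)       # text embedding, e.g. from a sentence encoder
    z = torch.randn(batch_size, 100, 1, 1)     # latent noise, (B, noise_dim, 1, 1)

    netG = generator()
    netD = discriminator()

    fake = netG(embed, z)
    assert fake.shape == (batch_size, 3, 64, 64)

    score, features = netD(fake, embed)
    assert score.shape == (batch_size,)
    assert features.shape == (batch_size, 64 * 8, 4, 4)
    print("generator/discriminator forward passes OK")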