File size: 5,759 Bytes
f0c7f08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import torch
import torch.nn as nn
import torch.nn.functional as F
def weights_init(m):
    """Xavier-initialise the weight and zero the bias of conv-like modules.

    Meant to be used as ``module.apply(weights_init)``. Only modules whose
    class name contains ``'Conv'`` are touched. If such a module has no usable
    ``weight`` or ``bias`` (so an ``AttributeError`` is raised, e.g. a wrapper
    module or ``bias=None``), a skip message is printed instead of failing.
    Note that when only ``bias`` is missing, the weight has already been
    initialised by the time the message is printed.
    """
    kind = type(m).__name__
    if 'Conv' not in kind:
        return
    try:
        nn.init.xavier_uniform_(m.weight.data)
        m.bias.data.fill_(0)
    except AttributeError:
        print("Skipping initialization of ", kind)
class GatedActivation(nn.Module):
    """Gated unit: split dim 1 into two halves (a, b) and return tanh(a) * sigmoid(b).

    The output has half the channels of the input. The two halves must have
    matching shapes, so the channel count is presumably always even here.
    """

    def forward(self, x):
        content, gate = torch.chunk(x, 2, dim=1)
        return torch.tanh(content) * torch.sigmoid(gate)
class GatedMaskedConv2d(nn.Module):
    """One gated, class-conditioned, causally masked PixelCNN layer.

    Two variants, selected by ``bh_model``:

    * ``bh_model=False``: only the vertical stack is used. Its kernel is one
      column wide, so output row ``r`` is computed from input rows
      ``r - kernel // 2 .. r`` of the same column. The gated result goes
      through the 1x1 ``horiz_resid`` conv, and the same tensor is returned
      as both ``out_v`` and ``out_h``.
    * ``bh_model=True``: the vertical kernel is 3 columns wide, and a
      horizontal stack computes output column ``j`` from input columns
      ``j - 1`` and ``j`` of the same row. The vertical features are injected
      into it through the 1x1 ``vert_to_horiz`` conv.

    Args:
        mask_type: ``'A'`` zeroes the last vertical kernel row and the last
            horizontal kernel column before every forward, so row ``r`` /
            column ``j`` of the input is excluded. Any other value (``'B'`` in
            practice) leaves the weights unmasked.
        dim: channel count of ``x_v`` / ``x_h`` (input and output).
        kernel: odd kernel size. The vertical kernel is ``kernel // 2 + 1``
            rows tall.
        residual: add the layer input to the returned stream.
        n_classes: size of the class-conditioning embedding table.
        bh_model: enable the vertical + horizontal variant described above.

    Raises:
        ValueError: if ``kernel`` is even.
    """

    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10, bh_model=False):
        super().__init__()
        # Raise instead of assert: the original `assert ..., print(...)` printed the
        # message and raised a bare AssertionError, and asserts vanish under -O.
        if kernel % 2 != 1:
            raise ValueError(f"Kernel size must be odd, got {kernel}")
        self.mask_type = mask_type
        self.residual = residual
        self.bh_model = bh_model
        # One 2*dim bias per class, added before each gate.
        self.class_cond_embedding = nn.Embedding(n_classes, 2 * dim)

        # Vertical stack: ceil(kernel/2) rows tall, 3 columns (bh) or 1 column wide.
        # After cropping the output to the input height, row r sees rows r-kernel//2..r.
        kernel_shp = (kernel // 2 + 1, 3 if self.bh_model else 1)
        padding_shp = (kernel // 2, 1 if self.bh_model else 0)
        self.vert_stack = nn.Conv2d(dim, dim * 2, kernel_shp, 1, padding_shp)
        # Used only when bh_model=True, but always created so parameter names,
        # order and checkpoints stay the same in both variants.
        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)

        # Horizontal stack: a 1x2 kernel. After cropping to the input width,
        # column j sees columns j-1..j.
        self.horiz_stack = nn.Conv2d(dim, dim * 2, (1, 2), 1, (0, 1))
        self.horiz_resid = nn.Conv2d(dim, dim, 1)
        self.gate = GatedActivation()

    def make_causal(self):
        """Zero the kernel taps that would let a mask-'A' layer see the current position.

        This is done in place on ``.data`` at the start of every forward, so any
        update to these entries since the last call is discarded before use.
        """
        self.vert_stack.weight.data[:, :, -1].zero_()  # last kernel row -> current row
        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # last kernel column -> current column

    def forward(self, x_v, x_h, h):
        """Apply the layer to the vertical stream ``x_v`` and horizontal stream ``x_h``.

        ``h`` holds the conditioning class indices. It is presumably shape (B,),
        since the embedding output is indexed as 2-D below.

        Returns:
            ``(out_v, out_h)``. When ``bh_model`` is False, both are the same tensor.
        """
        if self.mask_type == 'A':
            self.make_causal()
        cond = self.class_cond_embedding(h)[:, :, None, None]  # broadcast over H, W

        # Crop the padded conv output back to the input height (keeps causality).
        h_vert = self.vert_stack(x_v)[:, :, :x_v.size(-2), :]
        out_v = self.gate(h_vert + cond)

        if self.bh_model:
            # Crop to the input width so column j only uses columns <= j.
            h_horiz = self.horiz_stack(x_h)[:, :, :, :x_h.size(-1)]
            out = self.gate(self.vert_to_horiz(h_vert) + h_horiz + cond)
            out_h = self.horiz_resid(out)
            if self.residual:
                out_h = out_h + x_h
            return out_v, out_h

        out_v = self.horiz_resid(out_v)
        if self.residual:
            out_v = out_v + x_v
        return out_v, out_v
class GatedPixelCNN(nn.Module):
    """Class-conditional Gated PixelCNN over a 2-D grid of integer tokens, with optional audio fusion.

    Args:
        input_dim: token vocabulary size. It is also the number of output
            logits per position.
        dim: hidden channel width.
        n_layers: number of ``GatedMaskedConv2d`` layers. Layer 0 uses mask
            'A', kernel 7 and no residual. The rest use mask 'B', kernel 3 and
            a residual.
        n_classes: number of conditioning labels.
        audio: if truthy, audio features are projected by a 1x1 conv and fused
            into the stream(s) right before layer 1. This needs
            ``n_layers >= 2`` to have any effect.
        bh_model: use the vertical + horizontal layer variant and read the
            logits from the horizontal stream.
        aud_dim: channel count of the raw audio features fed to ``forward``.
            Defaults to 256, the previously hard-coded value.
    """

    def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10, audio=False, bh_model=False,
                 aud_dim=256):
        super().__init__()
        self.dim = dim
        self.audio = audio
        self.bh_model = bh_model
        # Registration order is kept as-is: it fixes parameter order (optimizer
        # state) and the RNG order of the weight init below.
        if self.audio:
            self.embedding_aud = nn.Conv2d(aud_dim, dim, 1, 1, padding=0)
            self.fusion_v = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
            self.fusion_h = nn.Conv2d(dim * 2, dim, 1, 1, padding=0)
        # Token id -> dim-dimensional embedding.
        self.embedding = nn.Embedding(input_dim, dim)

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            mask_type = 'A' if i == 0 else 'B'
            kernel = 7 if i == 0 else 3
            residual = i != 0
            self.layers.append(
                GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes, bh_model)
            )

        # 1x1 head producing input_dim logits per position.
        self.output_conv = nn.Sequential(
            nn.Conv2d(dim, 512, 1),
            nn.ReLU(True),
            nn.Conv2d(512, input_dim, 1),
        )
        # weights_init also visits the GatedMaskedConv2d wrappers (their class
        # name contains 'Conv'); it prints a skip message for those.
        self.apply(weights_init)
        # Row-dropout applied to the audio features in _fuse_audio.
        self.dp = nn.Dropout(0.1)

    def _fuse_audio(self, x_v, x_h, aud):
        """Project ``aud`` to ``dim`` channels and fuse it into the stream(s) with 1x1 convs."""
        if aud is None:
            raise ValueError("GatedPixelCNN was built with audio enabled; forward() requires `aud`")
        aud = self.embedding_aud(aud)
        # One dropout draw per row (dim -2), shared across batch, channels and
        # columns. In training, kept rows are scaled by 1 / (1 - p).
        row_keep = self.dp(torch.ones(aud.shape[-2], device=aud.device, dtype=aud.dtype))
        aud = aud * row_keep[:, None]
        x_v = self.fusion_v(torch.cat([x_v, aud], dim=1))
        if self.bh_model:
            x_h = self.fusion_h(torch.cat([x_h, aud], dim=1))
        return x_v, x_h

    def forward(self, x, label, aud=None):
        """Compute per-position logits for a token grid.

        Args:
            x: integer token grid. The embedding + ``permute(0, 3, 1, 2)``
                below treat it as (B, H, W).
            label: conditioning class indices, passed to every layer
                (presumably shape (B,)).
            aud: audio features. Required when the model was built with audio
                and has at least 2 layers. Presumably (B, aud_dim, H, W), since
                after projection it is concatenated on dim 1 with the
                (B, dim, H, W) stream.

        Returns:
            Logits of shape (B, input_dim, H, W).

        Raises:
            ValueError: audio fusion is reached but ``aud`` is None.
        """
        x = self.embedding(x).permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
        x_v = x_h = x
        for i, layer in enumerate(self.layers):
            # Truthiness, matching __init__: any truthy `audio` builds the
            # fusion layers, so it must also use them.
            if i == 1 and self.audio:
                x_v, x_h = self._fuse_audio(x_v, x_h, aud)
            x_v, x_h = layer(x_v, x_h, label)
        return self.output_conv(x_h if self.bh_model else x_v)

    def generate(self, label, shape=(8, 8), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
        """Sample a grid of tokens autoregressively, row by row, left to right.

        Runs one full forward pass per sampled position under ``torch.no_grad()``.
        The train/eval mode is not changed; call ``.eval()`` first if dropout
        should be disabled.

        Args:
            label: conditioning labels forwarded to every forward pass.
            shape: (rows, cols) of the grid to sample.
            batch_size: number of samples.
            aud_feat: audio features. Only used when the model was built with audio.
            pre_latents: optional known context rows, (batch_size, h0, cols).
                They are prepended before sampling and converted to the
                sample tensor's dtype/device.
            pre_audio: audio features for the context rows. Concatenated before
                ``aud_feat`` along dim 2 when ``pre_latents`` is given; ignored
                when None.

        Returns:
            int64 tensor (batch_size, rows, cols) of sampled tokens. Context
            rows are excluded.
        """
        param = next(self.parameters())
        x = torch.zeros((batch_size, *shape), dtype=torch.int64, device=param.device)
        h0 = 0
        if pre_latents is not None:
            x = torch.cat([pre_latents.to(x), x], dim=1)
            h0 = pre_latents.shape[1]
            # Only audio models pass pre_audio. Skipping the concat when it is None
            # lets non-audio models use pre_latents (torch.cat on None raises).
            if pre_audio is not None:
                aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
        h_end = h0 + shape[0]

        with torch.no_grad():  # sampling never needs gradients
            for i in range(h0, h_end):
                for j in range(shape[1]):
                    logits = self.forward(x, label, aud_feat if self.audio else None)
                    probs = F.softmax(logits[:, :, i, j], -1)
                    x[:, i, j] = probs.multinomial(1).squeeze(-1)
        return x[:, h0:h_end]
|