from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
from pg_modules.projector import F_RandomProj
from pg_modules.diffaug import DiffAugment


class SingleDisc(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
        super().__init__()
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # snap start_sz to the nearest resolution covered by channel_dict
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if ndf is given, use the same width at every resolution
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # when the input is a feature map (no head), its channel count
        # overrides the table entry at the starting resolution
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # optional head for inputs that are full images rather than features
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # stack of down blocks until the spatial size reaches end_sz
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        # final 4x4 conv produces a map of per-patch logits
        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # c is ignored; the argument only keeps the signature identical
        # to the conditional variant below
        return self.main(x)
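
# Worked example (assuming each DownBlock halves the spatial resolution, as
# its use above implies): with start_sz=256 and end_sz=8 the loop builds
# blocks 256->128->64->32->16->8, so a 256x256 input reaches an 8x8 map with
# nfc[8]=512 channels, and the closing 4x4 conv yields a 5x5 logit map.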


class SingleDiscCond(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128):
        super().__init__()
        self.cmap_dim = cmap_dim

        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # snap start_sz to the nearest resolution covered by channel_dict
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if ndf is given, use the same width at every resolution
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # when the input is a feature map (no head), its channel count
        # overrides the table entry at the starting resolution
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # optional head for inputs that are full images rather than features
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # stack of down blocks until the spatial size reaches end_sz
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2
        self.main = nn.Sequential(*layers)

        # class conditioning: project features to cmap_dim channels and
        # embed the class label into the same space
        self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
        self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
        self.embed_proj = nn.Sequential(
            nn.Linear(self.embed.embedding_dim, self.cmap_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, c):
        h = self.main(x)
        out = self.cls(h)

        # embed the one-hot label, project it to cmap_dim, and broadcast it
        # over the spatial logit map
        cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
        out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return out
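
# The conditioning above follows the projection-discriminator formulation
# (Miyato & Koyama, 2018): an inner product between the class embedding and
# the feature map along the channel dimension, scaled by 1/sqrt(cmap_dim) so
# the logit magnitude is roughly independent of cmap_dim.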


class MultiScaleD(nn.Module):
    def __init__(
        self,
        channels,
        resolutions,
        num_discs=1,
        proj_type=2,
        cond=0,
        separable=False,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4]

        # use the first num_discs feature scales of the projector
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]
        Disc = SingleDiscCond if cond else SingleDisc

        # one mini-discriminator per feature scale, keyed by scale index
        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            mini_discs.append([str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)])
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c):
        # features is a dict keyed by scale index ('0', '1', ...); each
        # mini-discriminator's logit map is flattened, and the per-scale
        # logits are concatenated along dim 1
        all_logits = []
        for k, disc in self.mini_discs.items():
            all_logits.append(disc(features[k], c).view(features[k].size(0), -1))

        all_logits = torch.cat(all_logits, dim=1)
        return all_logits
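
# With num_discs > 1, the heads cover the first num_discs entries of the
# projector's CHANNELS/RESOLUTIONS, so real/fake evidence is aggregated
# across several feature scales of the backbone rather than a single one.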


class ProjectedDiscriminator(torch.nn.Module):
    def __init__(
        self,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.diffaug = diffaug
        self.interp224 = interp224
        self.feature_network = F_RandomProj(**backbone_kwargs)
        self.discriminator = MultiScaleD(
            channels=self.feature_network.CHANNELS,
            resolutions=self.feature_network.RESOLUTIONS,
            **backbone_kwargs,
        )

    def train(self, mode=True):
        # the pretrained feature network stays frozen in eval mode;
        # only the multi-scale discriminator heads toggle train/eval
        self.feature_network = self.feature_network.train(False)
        self.discriminator = self.discriminator.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x, c):
        if self.diffaug:
            x = DiffAugment(x, policy='color,translation,cutout')

        # resize to the backbone's expected input resolution
        if self.interp224:
            x = F.interpolate(x, 224, mode='bilinear', align_corners=False)

        features = self.feature_network(x)
        logits = self.discriminator(features, c)

        return logits
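
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It
# assumes F_RandomProj's defaults can build their pretrained backbone in this
# environment and that inputs are RGB images; batch and label shapes below
# are placeholders.
if __name__ == '__main__':
    D = ProjectedDiscriminator(diffaug=False)
    x = torch.randn(2, 3, 256, 256)  # e.g. a batch of generator outputs
    c = F.one_hot(torch.tensor([3, 7]), num_classes=1000).float()  # one-hot labels (unused unless cond=1)
    logits = D(x, c)  # concatenated per-patch logits, shape (2, N)
    print(logits.shape)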