import os
import json
import tempfile
from random import random
import math
from math import log2, floor
from pathlib import Path
from functools import partial
from contextlib import contextmanager, ExitStack
from shutil import rmtree

import torch
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad

from PIL import Image
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from kornia.filters import filter2d

from huggan.pytorch.lightweight_gan.diff_augment import DiffAugment

from tqdm import tqdm
from einops import rearrange, reduce, repeat

from datasets import load_dataset

from accelerate import Accelerator, DistributedDataParallelKwargs
from huggingface_hub import hf_hub_download, create_repo

from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
from huggan.utils.hub import get_full_repo_name

# constants

EXTS = ['jpg', 'jpeg', 'png']
PYTORCH_WEIGHTS_NAME = 'model.pt'


# helper functions

def exists(val):
    return val is not None


@contextmanager
def null_context():
    yield


def is_power_of_two(val):
    return log2(val).is_integer()


def default(val, d):
    return val if exists(val) else d


def set_requires_grad(model, value):
    for p in model.parameters():
        p.requires_grad = value


def cycle(iterable):
    while True:
        for i in iterable:
            yield i


def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException


def evaluate_in_chunks(max_batch_size, model, *args):
    # split each argument along the batch dimension, run the model chunk by
    # chunk, then concatenate the outputs back together
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)


def slerp(val, low, high):
    # spherical linear interpolation between two batches of latent vectors
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res


def safe_div(n, d):
    try:
        res = n / d
    except ZeroDivisionError:
        prefix = '' if int(n >= 0) else '-'
        res = float(f'{prefix}inf')
    return res


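# A minimal usage sketch (illustrative values, kept as comments so nothing
# runs on import): `evaluate_in_chunks` bounds peak memory by splitting the
# batch, and `slerp` walks the great circle between two latent batches.
#
#   g = nn.Linear(256, 8)                      # stand-in for a generator
#   latents = torch.randn(100, 256)
#   out = evaluate_in_chunks(32, g, latents)   # 4 chunks -> (100, 8)
#
#   low, high = torch.randn(4, 256), torch.randn(4, 256)
#   mid = slerp(0.5, low, high)                # halfway point, shape (4, 256)

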
# losses

def gen_hinge_loss(fake, real):
    return fake.mean()


def hinge_loss(real, fake):
    # note the sign convention: the discriminator pushes real logits below -1
    # and fake logits above +1, and the generator (above) simply minimizes the
    # fake logits - consistent, just flipped from the textbook hinge form
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()


def dual_contrastive_loss(real_logits, fake_logits):
    device = real_logits.device
    real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

    def loss_half(t1, t2):
        t1 = rearrange(t1, 'i -> i ()')
        t2 = repeat(t2, 'j -> i j', i=t1.shape[0])
        t = torch.cat((t1, t2), dim=-1)
        return F.cross_entropy(t, torch.zeros(t1.shape[0], device=device, dtype=torch.long))

    return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)


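# Sanity sketch for the sign convention above (illustrative, comments only):
# a confident discriminator zeroes both hinge terms.
#
#   real = torch.tensor([-2., -3.]); fake = torch.tensor([2., 3.])
#   hinge_loss(real, fake)       # -> tensor(0.), both relu terms inactive
#   gen_hinge_loss(fake, real)   # -> tensor(2.5), G pushes this down
#
# `dual_contrastive_loss` instead treats each real logit as the positive
# class in a softmax over itself and all fake logits, so there it is high
# real logits that are rewarded.

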
class NanException(Exception):
    pass


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new


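# Sketch of the exponential moving average update (comments only): with
# beta = 0.995 the running copy moves 0.5% of the way toward the new value
# per call, which is how the `GE` generator below tracks `G`.
#
#   ema = EMA(0.995)
#   avg = None
#   for value in (1.0, 2.0, 3.0):
#       avg = ema.update_average(avg, value)   # 1.0, 1.005, 1.014975

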
class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else=lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)


class ChanNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class SumBranches(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return sum(map(lambda fn: fn(x), self.branches))


class Fuzziness(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)


Blur = nn.Identity


class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride,
                      bias=bias),
            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        )

    def forward(self, x):
        return self.net(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.nonlin = nn.GELU()
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=h), (q, k, v))

        q = q.softmax(dim=-1)
        k = k.softmax(dim=-2)

        q = q * self.scale

        # linear attention: aggregate a (dim_head x dim_head) context once,
        # instead of forming an (x*y) x (x*y) attention map
        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h=h, x=x, y=y)

        out = self.nonlin(out)
        return self.to_out(out)
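

# Shape sketch (illustrative, comments only): the attention cost is linear
# in the number of pixels because the context matrix is only
# dim_head x dim_head.
#
#   attn = LinearAttention(dim=64)             # inner_dim = 64 * 8 = 512
#   out = attn(torch.randn(2, 64, 32, 32))     # -> (2, 64, 32, 32)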


def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image


class identity(object):
    def __call__(self, tensor):
        return tensor


class expand_greyscale(object):
    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        num_target_channels = 4 if self.transparent else 3

        if channels == num_target_channels:
            return tensor

        alpha = None
        if channels == 1:
            color = tensor.expand(3, -1, -1)
        elif channels == 2:
            color = tensor[:1].expand(3, -1, -1)
            alpha = tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not exists(alpha) and self.transparent:
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)

        return color if not self.transparent else torch.cat((color, alpha))


def resize_to_minimum_size(min_size, image):
    if max(*image.size) < min_size:
        return torchvision.transforms.functional.resize(image, min_size)
    return image


def random_hflip(tensor, prob):
    # flip the batch horizontally with probability `prob`
    if prob < random():
        return tensor
    return torch.flip(tensor, dims=(3,))


class AugWrapper(nn.Module):
    def __init__(self, D, image_size):
        super().__init__()
        self.D = D

    def forward(self, images, prob=0., types=[], detach=False, **kwargs):
        context = torch.no_grad if detach else null_context

        with context():
            if random() < prob:
                images = random_hflip(images, prob=0.5)
                images = DiffAugment(images, types=types)

        return self.D(images, **kwargs)
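

# Usage sketch (illustrative, comments only): the wrapper runs DiffAugment
# on a `prob` fraction of batches before they reach D, which is what lets
# the discriminator train on small datasets without memorizing them.
#
#   D_aug = AugWrapper(discriminator, image_size=128)   # `discriminator` assumed
#   logits, logits_32x32, _ = D_aug(images, prob=0.25,
#                                   types=['translation', 'cutout'])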


norm_class = nn.BatchNorm2d


def upsample(scale_factor=2):
    return nn.Upsample(scale_factor=scale_factor)


# global context network - a squeeze-excite style gate with attention
# pooling (https://arxiv.org/abs/1904.11492)

class GlobalContext(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        chan_intermediate = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # pool the feature map with a learned single-query attention, then
        # produce per-channel gates in [0, 1]
        context = self.to_k(x)
        context = context.flatten(2).softmax(dim=-1)
        out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
        out = out.unsqueeze(-1)
        return self.net(out)


# frequency channel attention (https://arxiv.org/abs/2010.15688)

def get_1d_dct(i, freq, L):
    result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    return result * (1 if freq == 0 else math.sqrt(2))


def get_dct_weights(width, channel, fidx_u, fidx_v):
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        for x in range(width):
            for y in range(width):
                coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
                dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value

    return dct_weights


class FCANet(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction=4,
        width
    ):
        super().__init__()

        freq_w, freq_h = ([0] * 8), list(range(8))
        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
        self.register_buffer('dct_weights', dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # pool against the fixed DCT basis instead of plain average pooling
        x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1=1, w1=1)
        return self.net(x)


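# Sketch (illustrative, comments only): the DCT weights give each channel
# group a fixed 2D frequency basis to pool against, making FCANet a
# frequency-domain generalization of squeeze-excite's average pooling.
#
#   w = get_dct_weights(width=8, channel=64, fidx_u=[0, 1], fidx_v=[1, 0])
#   w.shape                                    # -> torch.Size([1, 64, 8, 8])
#   fca = FCANet(chan_in=64, chan_out=32, width=8)
#   gates = fca(torch.randn(2, 64, 8, 8))      # -> (2, 32, 1, 1), in [0, 1]

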
# generator

class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=[],
        freq_chan_attn=False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1)
        )

        num_layers = int(resolution) - 2
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # skip-layer excitation: each low-resolution layer gates a layer four
        # octaves up
        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out,
                        chan_out=sle_chan_out,
                        width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in=chan_out,
                        chan_out=sle_chan_out
                    )

            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
                    norm_class(chan_out * 2),
                    nn.GLU(dim=1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)

        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)


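# Usage sketch (illustrative, comments only): latents go in as
# (batch, latent_dim), images come out at the configured resolution.
#
#   g = Generator(image_size=128, latent_dim=256)
#   imgs = g(torch.randn(2, 256))              # -> (2, 3, 128, 128)

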
class SimpleDecoder(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out=3,
        num_upsamples=4,
    ):
        super().__init__()

        self.layers = nn.ModuleList([])
        final_chan = chan_out
        chans = chan_in

        for ind in range(num_upsamples):
            last_layer = ind == (num_upsamples - 1)
            chan_out = chans if not last_layer else final_chan * 2
            layer = nn.Sequential(
                upsample(),
                nn.Conv2d(chans, chan_out, 3, padding=1),
                nn.GLU(dim=1)
            )
            self.layers.append(layer)
            chans //= 2

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# discriminator

class Discriminator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        disc_output_size=5,
        attn_res_layers=[]
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'
        assert disc_output_size in {1, 5}, 'discriminator output dimensions can only be 5x5 or 1x1'

        resolution = int(resolution)

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        num_non_residual_layers = max(0, int(resolution) - 8)
        num_residual_layers = 8 - 3

        non_residual_resolutions = range(min(8, resolution), 2, -1)
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), non_residual_resolutions))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))

        if num_non_residual_layers == 0:
            res, _ = features[0]
            features[0] = (res, init_channel)

        chan_in_out = list(zip(features[:-1], features[1:]))

        self.non_residual_layers = nn.ModuleList([])
        for ind in range(num_non_residual_layers):
            first_layer = ind == 0
            last_layer = ind == (num_non_residual_layers - 1)
            chan_out = features[0][-1] if last_layer else init_channel

            self.non_residual_layers.append(nn.Sequential(
                Blur(),
                nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1)
            ))

        self.residual_layers = nn.ModuleList([])

        for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            self.residual_layers.append(nn.ModuleList([
                SumBranches([
                    nn.Sequential(
                        Blur(),
                        nn.Conv2d(chan_in, chan_out, 4, stride=2, padding=1),
                        nn.LeakyReLU(0.1),
                        nn.Conv2d(chan_out, chan_out, 3, padding=1),
                        nn.LeakyReLU(0.1)
                    ),
                    nn.Sequential(
                        Blur(),
                        nn.AvgPool2d(2),
                        nn.Conv2d(chan_in, chan_out, 1),
                        nn.LeakyReLU(0.1),
                    )
                ]),
                attn
            ]))

        last_chan = features[-1][-1]
        if disc_output_size == 5:
            self.to_logits = nn.Sequential(
                nn.Conv2d(last_chan, last_chan, 1),
                nn.LeakyReLU(0.1),
                nn.Conv2d(last_chan, 1, 4)
            )
        elif disc_output_size == 1:
            self.to_logits = nn.Sequential(
                Blur(),
                nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
                nn.LeakyReLU(0.1),
                nn.Conv2d(last_chan, 1, 4)
            )

        self.to_shape_disc_out = nn.Sequential(
            nn.Conv2d(init_channel, 64, 3, padding=1),
            Residual(PreNorm(64, LinearAttention(64))),
            SumBranches([
                nn.Sequential(
                    Blur(),
                    nn.Conv2d(64, 32, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.1),
                    nn.Conv2d(32, 32, 3, padding=1),
                    nn.LeakyReLU(0.1)
                ),
                nn.Sequential(
                    Blur(),
                    nn.AvgPool2d(2),
                    nn.Conv2d(64, 32, 1),
                    nn.LeakyReLU(0.1),
                )
            ]),
            Residual(PreNorm(32, LinearAttention(32))),
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Conv2d(32, 1, 4)
        )

        self.decoder1 = SimpleDecoder(chan_in=last_chan, chan_out=init_channel)
        self.decoder2 = SimpleDecoder(chan_in=features[-2][-1], chan_out=init_channel) if resolution >= 9 else None

    def forward(self, x, calc_aux_loss=False):
        orig_img = x

        for layer in self.non_residual_layers:
            x = layer(x)

        layer_outputs = []

        for (net, attn) in self.residual_layers:
            if exists(attn):
                x = attn(x) + x

            x = net(x)
            layer_outputs.append(x)

        out = self.to_logits(x).flatten(1)

        img_32x32 = F.interpolate(orig_img, size=(32, 32))
        out_32x32 = self.to_shape_disc_out(img_32x32)

        if not calc_aux_loss:
            return out, out_32x32, None

        # self-supervised auto-encoding loss: reconstruct the (downsampled)
        # input from intermediate discriminator features
        layer_8x8 = layer_outputs[-1]
        layer_16x16 = layer_outputs[-2]

        recon_img_8x8 = self.decoder1(layer_8x8)

        aux_loss = F.mse_loss(
            recon_img_8x8,
            F.interpolate(orig_img, size=recon_img_8x8.shape[2:])
        )

        if exists(self.decoder2):
            select_random_quadrant = lambda rand_quadrant, img: \
                rearrange(img, 'b c (m h) (n w) -> (m n) b c h w', m=2, n=2)[rand_quadrant]
            crop_image_fn = partial(select_random_quadrant, floor(random() * 4))
            img_part, layer_16x16_part = map(crop_image_fn, (orig_img, layer_16x16))

            recon_img_16x16 = self.decoder2(layer_16x16_part)

            aux_loss_16x16 = F.mse_loss(
                recon_img_16x16,
                F.interpolate(img_part, size=recon_img_16x16.shape[2:])
            )

            aux_loss = aux_loss + aux_loss_16x16

        return out, out_32x32, aux_loss
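

# Usage sketch (illustrative, comments only): D returns the main patch
# logits, the 32x32 "shape" logits, and, when requested, the
# self-supervised reconstruction loss used as a regularizer.
#
#   d = Discriminator(image_size=128, disc_output_size=5)
#   out, out_32x32, aux = d(torch.randn(2, 3, 128, 128), calc_aux_loss=True)
#   out.shape                                  # -> (2, 25), 5x5 patch logits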


# GAN

class LightweightGAN(nn.Module, HugGANModelHubMixin):
    def __init__(
        self,
        *,
        latent_dim,
        image_size,
        optimizer="adam",
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        disc_output_size=5,
        attn_res_layers=[],
        freq_chan_attn=False,
        ttur_mult=1.,
        lr=2e-4,
    ):
        super().__init__()

        self.config = {
            'latent_dim': latent_dim,
            'image_size': image_size,
            'optimizer': optimizer,
            'fmap_max': fmap_max,
            'fmap_inverse_coef': fmap_inverse_coef,
            'transparent': transparent,
            'greyscale': greyscale,
            'disc_output_size': disc_output_size,
            'attn_res_layers': attn_res_layers,
            'freq_chan_attn': freq_chan_attn,
            'ttur_mult': ttur_mult,
            'lr': lr
        }

        self.latent_dim = latent_dim
        self.image_size = image_size

        G_kwargs = dict(
            image_size=image_size,
            latent_dim=latent_dim,
            fmap_max=fmap_max,
            fmap_inverse_coef=fmap_inverse_coef,
            transparent=transparent,
            greyscale=greyscale,
            attn_res_layers=attn_res_layers,
            freq_chan_attn=freq_chan_attn
        )

        self.G = Generator(**G_kwargs)

        self.D = Discriminator(
            image_size=image_size,
            fmap_max=fmap_max,
            fmap_inverse_coef=fmap_inverse_coef,
            transparent=transparent,
            greyscale=greyscale,
            attn_res_layers=attn_res_layers,
            disc_output_size=disc_output_size
        )

        self.ema_updater = EMA(0.995)
        self.GE = Generator(**G_kwargs)
        set_requires_grad(self.GE, False)

        if optimizer == "adam":
            self.G_opt = Adam(self.G.parameters(), lr=lr, betas=(0.5, 0.9))
            self.D_opt = Adam(self.D.parameters(), lr=lr * ttur_mult, betas=(0.5, 0.9))
        elif optimizer == "adabelief":
            from adabelief_pytorch import AdaBelief

            self.G_opt = AdaBelief(self.G.parameters(), lr=lr, betas=(0.5, 0.9))
            self.D_opt = AdaBelief(self.D.parameters(), lr=lr * ttur_mult, betas=(0.5, 0.9))
        else:
            assert False, "No valid optimizer is given"

        self.apply(self._init_weights)
        self.reset_parameter_averaging()

        self.D_aug = AugWrapper(self.D, image_size)

    def _init_weights(self, m):
        if type(m) in {nn.Conv2d, nn.Linear}:
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def EMA(self):
        def update_moving_average(ma_model, current_model):
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                old_weight, up_weight = ma_params.data, current_params.data
                ma_params.data = self.ema_updater.update_average(old_weight, up_weight)

            for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
                new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
                ma_buffer.copy_(new_buffer_value)

        update_moving_average(self.GE, self.G)

    def reset_parameter_averaging(self):
        self.GE.load_state_dict(self.G.state_dict())

    def forward(self, x):
        raise NotImplementedError

    def _save_pretrained(self, save_directory):
        """
        Overwrite this method in case you don't want to save the complete
        model, but rather some specific layers.
        """
        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
        model_to_save = self.module if hasattr(self, "module") else self

        torch.save({'GAN': model_to_save.state_dict()}, path)

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        token,
        map_location="cpu",
        strict=False,
        **model_kwargs,
    ):
        """
        Overwrite this method in case you wish to initialize your model in a
        different way.
        """
        map_location = torch.device(map_location)

        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=PYTORCH_WEIGHTS_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        model = cls(**model_kwargs['config'])

        state_dict = torch.load(model_file, map_location=map_location)
        model.load_state_dict(state_dict["GAN"], strict=strict)
        model.eval()

        return model


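# Usage sketch (illustrative, comments only): via the Hub mixin the whole
# GAN round-trips through a single 'GAN' state dict in model.pt; the local
# directory below is a placeholder, and config handling follows whatever
# HugGANModelHubMixin writes alongside the weights.
#
#   gan = LightweightGAN(latent_dim=256, image_size=128)
#   gan.save_pretrained('some-local-dir')            # writes model.pt
#   gan2 = LightweightGAN.from_pretrained('some-local-dir')
#   imgs = gan2.G(torch.randn(1, 256))               # raw generator sample

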
class Trainer():
    def __init__(
        self,
        dataset_name="huggan/CelebA-faces",
        name='default',
        results_dir='results',
        models_dir='models',
        base_dir='./',
        optimizer='adam',
        latent_dim=256,
        image_size=128,
        num_image_tiles=8,
        fmap_max=512,
        transparent=False,
        greyscale=False,
        batch_size=4,
        gp_weight=10,
        gradient_accumulate_every=1,
        attn_res_layers=[],
        freq_chan_attn=False,
        disc_output_size=5,
        dual_contrast_loss=False,
        antialias=False,
        lr=2e-4,
        lr_mlp=1.,
        ttur_mult=1.,
        save_every=10000,
        evaluate_every=1000,
        aug_prob=None,
        aug_types=['translation', 'cutout'],
        dataset_aug_prob=0.,
        calculate_fid_every=None,
        calculate_fid_num_images=12800,
        clear_fid_cache=False,
        log=False,
        cpu=False,
        mixed_precision="no",
        wandb=False,
        push_to_hub=False,
        organization_name=None,
        *args,
        **kwargs
    ):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.dataset_name = dataset_name

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.fid_dir = base_dir / 'fid' / name

        self.config_path = self.models_dir / name / 'config.json'

        assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(map(is_power_of_two,
                       attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        assert not (
            dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss'

        self.image_size = image_size
        self.num_image_tiles = num_image_tiles

        self.latent_dim = latent_dim
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.optimizer = optimizer
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.attn_res_layers = attn_res_layers
        self.freq_chan_attn = freq_chan_attn

        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.dual_contrast_loss = dual_contrast_loss

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every
        self.calculate_fid_num_images = calculate_fid_num_images
        self.clear_fid_cache = clear_fid_cache

        self.syncbatchnorm = torch.cuda.device_count() > 1 and not cpu

        self.cpu = cpu
        self.mixed_precision = mixed_precision

        self.wandb = wandb

        self.push_to_hub = push_to_hub
        self.organization_name = organization_name
        self.repo_name = get_full_repo_name(self.name, self.organization_name)
        if self.push_to_hub:
            self.repo_url = create_repo(self.repo_name, exist_ok=True)

    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        args, kwargs = self.GAN_params

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Fuzziness

        self.GAN = LightweightGAN(
            optimizer=self.optimizer,
            lr=self.lr,
            latent_dim=self.latent_dim,
            attn_res_layers=self.attn_res_layers,
            freq_chan_attn=self.freq_chan_attn,
            image_size=self.image_size,
            ttur_mult=self.ttur_mult,
            fmap_max=self.fmap_max,
            disc_output_size=self.disc_output_size,
            transparent=self.transparent,
            greyscale=self.greyscale,
            *args,
            **kwargs
        )

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
        self.image_size = config['image_size']
        self.transparent = config['transparent']
        self.syncbatchnorm = config['syncbatchnorm']
        self.disc_output_size = config['disc_output_size']
        self.greyscale = config.pop('greyscale', False)
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.freq_chan_attn = config.pop('freq_chan_attn', False)
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale,
            'syncbatchnorm': self.syncbatchnorm,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers,
            'freq_chan_attn': self.freq_chan_attn
        }

    def set_data_src(self):
        dataset = load_dataset(self.dataset_name)

        if self.transparent:
            num_channels = 4
            pillow_mode = 'RGBA'
            expand_fn = expand_greyscale(self.transparent)
        elif self.greyscale:
            num_channels = 1
            pillow_mode = 'L'
            expand_fn = identity()
        else:
            num_channels = 3
            pillow_mode = 'RGB'
            expand_fn = expand_greyscale(self.transparent)

        convert_image_fn = partial(convert_image_to, pillow_mode)

        transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Lambda(partial(resize_to_minimum_size, self.image_size)),
            transforms.Resize(self.image_size),
            RandomApply(0., transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
                        transforms.CenterCrop(self.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(expand_fn)
        ])

        def transform_images(examples):
            transformed_images = [transform(image.convert("RGB")) for image in examples["image"]]
            examples["image"] = torch.stack(transformed_images)
            return examples

        transformed_dataset = dataset.with_transform(transform_images)

        per_device_batch_size = math.ceil(self.batch_size / self.accelerator.num_processes)
        dataloader = DataLoader(transformed_dataset["train"], per_device_batch_size, sampler=None, shuffle=False,
                                drop_last=True, pin_memory=True)
        num_samples = len(transformed_dataset["train"])

        self.loader = dataloader

        # auto set augmentation probability if the dataset is detected to be small
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

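    # Worked example for the sizes above (assumed values, comments only):
    # batch_size=4 on 2 processes gives ceil(4 / 2) = 2 images per device
    # per step, and a 30k-image dataset autosets
    # aug_prob = min(0.5, (1e5 - 30000) * 3e-6) = 0.21.
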
    def init_accelerator(self):
        # let the accelerator handle device placement and DDP wrapping
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], mixed_precision=self.mixed_precision, cpu=self.cpu)

        if self.accelerator.is_local_main_process:
            if self.wandb:
                import wandb

                wandb.init(project=str(self.results_dir).split("/")[-1])

        if not exists(self.GAN):
            self.init_GAN()

        G = self.GAN.G
        D = self.GAN.D
        D_aug = self.GAN.D_aug

        self.set_data_src()

        # prepare models, optimizers and dataloader for (distributed) training
        G, D, D_aug, self.GAN.D_opt, self.GAN.G_opt, self.loader = self.accelerator.prepare(
            G, D, D_aug, self.GAN.D_opt, self.GAN.G_opt, self.loader
        )
        self.loader = cycle(self.loader)

        return G, D, D_aug

    def train(self, G, D, D_aug):
        assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=self.accelerator.device)
        total_gen_loss = torch.zeros([], device=self.accelerator.device)

        batch_size = math.ceil(self.batch_size / self.accelerator.num_processes)

        image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        apply_gradient_penalty = self.steps % 4 == 0

        # discriminator loss fn
        if self.dual_contrast_loss:
            D_loss_fn = dual_contrastive_loss
        else:
            D_loss_fn = hinge_loss

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for i in range(self.gradient_accumulate_every):
            latents = torch.randn(batch_size, latent_dim, device=self.accelerator.device)
            image_batch = next(self.loader)["image"]
            image_batch.requires_grad_()

            with torch.no_grad():
                generated_images = G(latents)

            fake_output, fake_output_32x32, _ = D_aug(generated_images, detach=True, **aug_kwargs)

            real_output, real_output_32x32, real_aux_loss = D_aug(image_batch, calc_aux_loss=True, **aug_kwargs)

            real_output_loss = real_output
            fake_output_loss = fake_output

            divergence = D_loss_fn(real_output_loss, fake_output_loss)
            divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32)
            disc_loss = divergence + divergence_32x32

            aux_loss = real_aux_loss
            disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                if self.accelerator.scaler is not None:
                    outputs = list(map(self.accelerator.scaler.scale, outputs))

                scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch,
                                              grad_outputs=list(
                                                  map(lambda t: torch.ones(t.size(), device=self.accelerator.device),
                                                      outputs)),
                                              create_graph=True, retain_graph=True, only_inputs=True)[0]

                inv_scale = 1.
                if self.accelerator.scaler is not None:
                    inv_scale = safe_div(1., self.accelerator.scaler.get_scale())

                # skip the penalty entirely when the loss scale has collapsed,
                # so `gradients` is only used when it is well defined
                if inv_scale != float('inf'):
                    gradients = scaled_gradients * inv_scale

                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            self.accelerator.backward(disc_loss)
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
        self.GAN.D_opt.step()

        # generator loss fn
        if self.dual_contrast_loss:
            G_loss_fn = dual_contrastive_loss
            G_requires_calc_real = True
        else:
            G_loss_fn = gen_hinge_loss
            G_requires_calc_real = False

        # train generator
        self.GAN.G_opt.zero_grad()

        for i in range(self.gradient_accumulate_every):
            latents = torch.randn(batch_size, latent_dim, device=self.accelerator.device)

            if G_requires_calc_real:
                image_batch = next(self.loader)["image"]
                image_batch.requires_grad_()

            generated_images = G(latents)

            fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
            real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (
                None, None, None)

            loss = G_loss_fn(fake_output, real_output)
            loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32)

            gen_loss = loss + loss_32x32

            gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            self.accelerator.backward(gen_loss)
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
        self.GAN.G_opt.step()

        # calculate moving averages
        if self.accelerator.is_main_process and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.accelerator.is_main_process and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors
        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results
        if self.accelerator.is_main_process:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

                if self.push_to_hub:
                    with tempfile.TemporaryDirectory() as temp_dir:
                        self.GAN.push_to_hub(temp_dir, self.repo_url, config=self.GAN.config, skip_lfs_files=True)

            if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles=self.num_image_tiles)

            if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1

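    # Training driver sketch (assumed entry point, comments only; the real
    # CLI lives outside this module). `num_train_steps` is a placeholder.
    #
    #   trainer = Trainer(dataset_name='huggan/CelebA-faces', image_size=128)
    #   G, D, D_aug = trainer.init_accelerator()
    #   for _ in range(num_train_steps):
    #       trainer.train(G, D, D_aug)
    #       if trainer.accelerator.is_main_process and trainer.steps % 50 == 0:
    #           trainer.print_log()
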
    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=4):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        latents = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)

        generated_images = self.generate_(self.GAN.G, latents)
        file_name = str(self.results_dir / self.name / f'{str(num)}.{ext}')
        save_image(generated_images, file_name, nrow=num_rows)

        generated_images = self.generate_(self.GAN.GE.to(self.accelerator.device), latents)
        file_name_ema = str(self.results_dir / self.name / f'{str(num)}-ema.{ext}')
        save_image(generated_images, file_name_ema, nrow=num_rows)

        if self.accelerator.is_local_main_process and self.wandb:
            import wandb

            wandb.log({'generated_examples': wandb.Image(str(file_name))})
            wandb.log({'generated_examples_ema': wandb.Image(str(file_name_ema))})

    @torch.no_grad()
    def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
        self.GAN.eval()

        latent_dim = self.GAN.latent_dim
        dir_name = f'{self.name}-generated-{checkpoint}'
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension

        if not dir_full.exists():
            os.mkdir(dir_full)

        if 'default' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
                latents = torch.randn(1, latent_dim, device=self.accelerator.device)
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
                save_image(generated_image[0], path, nrow=1)

        if 'ema' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
                latents = torch.randn(1, latent_dim, device=self.accelerator.device)
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
                save_image(generated_image[0], path, nrow=1)

        return dir_full

    @torch.no_grad()
    def show_progress(self, num_images=4, types=['default', 'ema']):
        checkpoints = self.get_checkpoints()
        assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'

        dir_name = self.name + '-progress'
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension
        latents = None

        zfill_length = math.ceil(math.log10(len(checkpoints)))

        if not dir_full.exists():
            os.mkdir(dir_full)

        for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
            self.load(checkpoint)
            self.GAN.eval()

            if checkpoint == 0:
                latents = torch.randn(num_images, self.GAN.latent_dim, device=self.accelerator.device)

            if 'default' in types:
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
                save_image(generated_image, path, nrow=num_images)

            if 'ema' in types:
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
                save_image(generated_image, path, nrow=num_images)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        from pytorch_fid import fid_score
        real_path = self.fid_dir / 'real'
        fake_path = self.fid_dir / 'fake'

        # remove any existing files used for fid calculation and recreate directories
        if not real_path.exists() or self.clear_fid_cache:
            rmtree(real_path, ignore_errors=True)
            os.makedirs(real_path)

            for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
                real_batch = next(self.loader)["image"]
                for k, image in enumerate(real_batch.unbind(0)):
                    ind = k + batch_num * self.batch_size
                    save_image(image, real_path / f'{ind}.png')

        rmtree(fake_path, ignore_errors=True)
        os.makedirs(fake_path)

        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
            latents = torch.randn(self.batch_size, latent_dim, device=self.accelerator.device)

            generated_images = self.generate_(self.GAN.GE, latents)

            for j, image in enumerate(generated_images.unbind(0)):
                ind = j + batch_num * self.batch_size
                save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)

    @torch.no_grad()
    def generate_(self, G, style, num_image_tiles=8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self, num=0, num_image_tiles=8, num_steps=100, save_frames=False):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        latents_low = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
        latents_high = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_(self.GAN.GE, interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                background = Image.new('RGBA', pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:],
                       duration=80, loop=0, optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        data = [
            ('G', self.g_loss),
            ('D', self.d_loss),
            ('GP', self.last_gp_loss),
            ('SS', self.last_recon_loss),
            ('FID', self.last_fid)
        ]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

        if self.accelerator.is_local_main_process:
            log_dict = {v[0]: v[1] for v in data}
            if self.wandb:
                import wandb

                wandb.log(log_dict)

    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.fid_dir), True)
        # the config path is a file, not a directory, so unlink it directly
        if self.config_path.exists():
            self.config_path.unlink()
        self.init_folders()

    def save(self, num):
        save_data = {
            'GAN': self.GAN.state_dict(),
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        self.load_config()

        name = num
        if num == -1:
            checkpoints = self.get_checkpoints()

            if not exists(checkpoints):
                return

            name = checkpoints[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print(
                'unable to load saved model. please try downgrading the package to the version specified by the saved model')
            raise e

    def get_checkpoints(self):
        file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
        saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))

        if len(saved_nums) == 0:
            return None

        return saved_nums