RetinaGAN / models /cstylegan.py
farrell236's picture
Upload 37 files
2aa6515
raw
history blame
19.6 kB
# This file is based on the StyleGAN example by Cheong et al.
# https://keras.io/examples/generative/stylegan/
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow_addons.layers import InstanceNormalization
def log2(x):
    """Return the base-2 logarithm of ``x`` truncated to a Python ``int``.

    Used to turn a power-of-two image resolution into its exponent
    (e.g. ``64 -> 6``). Any fractional part is dropped by the ``int``
    conversion.
    """
    exponent = np.log2(x)
    return int(exponent)
# Each resolution uses its own batch size so that larger images still fit into
# GPU memory. The keys are the image resolution in log2 (2 -> 4x4, 10 -> 1024x1024).
batch_sizes = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 2, 10: 1}
# Ratio of the base (log2 key 2) batch size to each resolution's batch size;
# presumably used by the training script to scale the number of train steps.
train_step_ratio = {k: batch_sizes[2] / v for k, v in batch_sizes.items()}
def fade_in(alpha, a, b):
    """Linearly blend ``a`` and ``b``.

    ``alpha`` is the weight given to ``a``; ``b`` receives ``1 - alpha``.
    With ``alpha == 1`` the result is ``a``, with ``alpha == 0`` it is ``b``.
    """
    weighted_a = alpha * a
    weighted_b = (1.0 - alpha) * b
    return weighted_a + weighted_b
def wasserstein_loss(y_true, y_pred):
    """Return the negated mean of ``y_true * y_pred``.

    In this file ``y_true`` is passed as +1 for "real" targets and -1 for
    "fake" targets, so the sign of the critic score is flipped accordingly.
    """
    signed_scores = y_true * y_pred
    return -tf.reduce_mean(signed_scores)
def pixel_norm(x, epsilon=1e-8):
    """Normalize ``x`` by the root-mean-square of its last axis.

    ``epsilon`` is added under the square root to avoid division by zero.
    """
    mean_square = tf.reduce_mean(x ** 2, axis=-1, keepdims=True)
    return x / tf.math.sqrt(mean_square + epsilon)
def minibatch_std(input_tensor, epsilon=1e-8):
    """Append a minibatch standard-deviation channel to a rank-4 tensor.

    The batch is reshaped so that its leading axis has ``min(4, batch)``
    entries; the standard deviation across that axis is averaged over the
    remaining height/width/channel axes and tiled back into a single extra
    channel that is concatenated onto ``input_tensor``.

    NOTE(review): the reshape presumably requires the batch size to be
    divisible by the group size - confirm against the data pipeline.
    """
    batch, height, width, channels = tf.shape(input_tensor)
    group_size = tf.minimum(4, batch)
    grouped = tf.reshape(input_tensor, [group_size, -1, height, width, channels])
    # Only the variance is needed; the mean returned by tf.nn.moments is discarded.
    _, variance = tf.nn.moments(grouped, axes=(0), keepdims=False)
    std = tf.sqrt(variance + epsilon)
    mean_std = tf.reduce_mean(std, axis=[1, 2, 3], keepdims=True)
    std_channel = tf.tile(mean_std, [group_size, height, width, 1])
    return tf.concat([input_tensor, std_channel], axis=-1)
class EqualizedConv(layers.Layer):
    """2-D convolution whose kernel is rescaled at call time.

    The kernel is initialized from a unit normal distribution and multiplied
    by ``sqrt(gain / fan_in)`` on every forward pass, where ``fan_in`` is
    ``kernel * kernel * in_channels``.

    Args:
        out_channels: number of output feature channels.
        kernel: spatial size of the square kernel.
        gain: numerator of the runtime weight scale.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.out_channels = out_channels
        self.gain = gain
        # A 1x1 kernel needs no border; every other size gets the one-pixel
        # reflect padding applied in ``call``.
        self.pad = kernel != 1

    def build(self, input_shape):
        """Create the kernel/bias weights and the runtime scale factor."""
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.w = self.add_weight(
            shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.kernel * self.kernel * self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        """Convolve ``inputs`` (NHWC) with the scaled kernel and add the bias."""
        if self.pad:
            x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
        else:
            x = inputs
        return tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b

    def get_config(self):
        """Return the constructor arguments so Keras can serialize the layer.

        Without this override Keras cannot serialize or clone a layer that
        takes extra ``__init__`` arguments.
        """
        config = super().get_config()
        config.update(
            {
                "out_channels": self.out_channels,
                "kernel": self.kernel,
                "gain": self.gain,
            }
        )
        return config
class EqualizedDense(layers.Layer):
    """Fully connected layer whose kernel is rescaled at call time.

    The kernel is multiplied by ``sqrt(gain / fan_in)`` on every forward
    pass. ``learning_rate_multiplier`` shrinks the initial weights (standard
    deviation ``1 / learning_rate_multiplier``) and multiplies the layer
    output by the same factor.

    Args:
        units: number of output features.
        gain: numerator of the runtime weight scale.
        learning_rate_multiplier: factor applied as described above.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.learning_rate_multiplier = learning_rate_multiplier

    def build(self, input_shape):
        """Create the kernel/bias weights and the runtime scale factor."""
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(
            mean=0.0, stddev=1.0 / self.learning_rate_multiplier
        )
        self.w = self.add_weight(
            shape=[self.in_channels, self.units],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        """Apply the scaled affine transform, then the learning-rate multiplier."""
        output = tf.add(tf.matmul(inputs, self.scale * self.w), self.b)
        return output * self.learning_rate_multiplier

    def get_config(self):
        """Return the constructor arguments so Keras can serialize the layer.

        Without this override Keras cannot serialize or clone a layer that
        takes extra ``__init__`` arguments.
        """
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "gain": self.gain,
                "learning_rate_multiplier": self.learning_rate_multiplier,
            }
        )
        return config
class AddNoise(layers.Layer):
    """Add noise to a feature map, scaled by a learned per-channel weight.

    The layer is called with ``[x, noise]``; the output is
    ``x + b * noise`` where ``b`` has shape ``[1, 1, 1, channels]``.
    """

    def build(self, input_shape):
        # input_shape[0] is the shape of the feature map; only its channel
        # count is needed to size the weight.
        _, _, _, channels = input_shape[0]
        self.b = self.add_weight(
            shape=[1, 1, 1, channels],
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
            name="kernel",
        )

    def call(self, inputs):
        features, noise = inputs
        return features + self.b * noise
class AdaIN(layers.Layer):
    """Modulate a feature map with a scale and bias predicted from a style vector.

    The layer is called with ``[x, w]``. Two ``EqualizedDense`` layers map
    ``w`` to a per-channel scale and bias, which are reshaped to
    ``(-1, 1, 1, channels)`` and applied as ``scale * x + bias``.

    Args:
        gain: gain passed to the two dense layers. It was previously stored
            but ignored (the dense layers hard-coded ``gain=1``); the default
            of 1 keeps existing behavior unchanged.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, gain=1, **kwargs):
        super().__init__(**kwargs)
        self.gain = gain

    def build(self, input_shapes):
        """Create the scale and bias projections sized to the feature map."""
        x_shape = input_shapes[0]
        w_shape = input_shapes[1]
        self.w_channels = w_shape[-1]
        self.x_channels = x_shape[-1]
        self.dense_1 = EqualizedDense(self.x_channels, gain=self.gain)
        self.dense_2 = EqualizedDense(self.x_channels, gain=self.gain)

    def call(self, inputs):
        x, w = inputs
        ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
        yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
        return ys * x + yb

    def get_config(self):
        """Return the constructor arguments so Keras can serialize the layer."""
        config = super().get_config()
        config.update({"gain": self.gain})
        return config
def Mapping(num_stages, input_shape=512):
    """Build the mapping network that turns (z, class embedding) into style codes.

    The latent ``z`` is pixel-normalized and passed through eight
    ``EqualizedDense(512)`` layers; the 512-wide class embedding is added
    after every dense layer, before the LeakyReLU. The result is repeated
    ``num_stages`` times along a new axis 1.

    Args:
        num_stages: number of copies of the style code to emit.
        input_shape: width of the latent vector ``z``.

    Returns:
        A ``keras.Model`` named ``"mapping"`` with inputs
        ``[z, class_embedding]``.
    """
    z = layers.Input(shape=(input_shape,))
    w = pixel_norm(z)
    class_embedding = layers.Input(shape=512)
    for _ in range(8):
        w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
        w = w + class_embedding
        w = layers.LeakyReLU(0.2)(w)
    per_stage = tf.expand_dims(w, 1)
    w = tf.tile(per_stage, (1, num_stages, 1))
    return keras.Model([z, class_embedding], w, name="mapping")
class Generator:
    """Builder that assembles generator models for progressive growing.

    One generator block and one ``to_rgb`` head are created per resolution
    from ``2 ** start_res_log2`` to ``2 ** target_res_log2``; ``grow`` wires
    the first ``n`` of them into a ``keras.Model`` for a given resolution.

    Args:
        start_res_log2: log2 of the lowest resolution.
        target_res_log2: log2 of the highest resolution.
    """

    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # list of generator blocks at increasing resolution
        self.g_blocks = []
        # list of layers to convert g_block activation to RGB
        self.to_rgb = []
        # list of noise input of different resolutions into g_blocks
        self.noise_inputs = []
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,  # 256x256
            9: 32,  # 512x512
            10: 16,  # 1024x1024
        }

        start_res = 2 ** start_res_log2
        self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
        self.g_input = layers.Input(self.input_shape, name="generator_input")

        for i in range(start_res_log2, target_res_log2 + 1):
            filter_num = self.filter_nums[i]
            res = 2 ** i
            self.noise_inputs.append(
                layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
            )
            to_rgb = Sequential(
                [
                    layers.InputLayer(input_shape=(res, res, filter_num)),
                    # 7 output channels; keep in sync with the Discriminator's
                    # from_rgb / input_image channel count.
                    EqualizedConv(7, 1, gain=1),
                ],
                name=f"to_rgb_{res}x{res}",
            )
            self.to_rgb.append(to_rgb)

            is_base = i == self.start_res_log2
            if is_base:
                # The base block is fed ``g_input`` directly, so it must be
                # declared with exactly that shape. Using filter_nums[i - 1]
                # here mismatched g_input's channel count whenever the start
                # resolution was above 32x32.
                input_shape = self.input_shape
            else:
                input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
            g_block = self.build_block(
                filter_num, res=res, input_shape=input_shape, is_base=is_base
            )
            self.g_blocks.append(g_block)

    def build_block(self, filter_num, res, input_shape, is_base):
        """Build one generator block as a ``keras.Model``.

        Non-base blocks first upsample by 2 and apply a 3x3 ``EqualizedConv``;
        the base block skips both. Every block then applies noise, LeakyReLU,
        instance normalization and AdaIN, followed by a second 3x3 convolution
        and the same noise/activation/normalization/AdaIN sequence.

        Args:
            filter_num: number of channels produced by the block.
            res: output resolution of the block (also the noise resolution).
            input_shape: shape (without batch) of the block's feature input.
            is_base: whether this is the lowest-resolution block.

        Returns:
            A model with inputs ``[input_tensor, w, noise]``.
        """
        input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
        noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
        w = layers.Input(shape=512)
        x = input_tensor

        if not is_base:
            x = layers.UpSampling2D((2, 2))(x)
            x = EqualizedConv(filter_num, 3)(x)

        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])

        x = EqualizedConv(filter_num, 3)(x)
        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])
        return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")

    def grow(self, res_log2):
        """Assemble the generator for resolution ``2 ** res_log2``.

        For the first stage the output is the base block's ``to_rgb`` image.
        For later stages the newest block's image is blended (``fade_in``)
        with the upsampled image of the previous stage using ``alpha``.

        Args:
            res_log2: log2 of the resolution to build; must lie within
                ``[start_res_log2, target_res_log2]``.

        Returns:
            A ``keras.Model`` with inputs
            ``[g_input, w, noise_inputs, alpha]``.

        Raises:
            ValueError: if ``res_log2`` is outside the configured range.
        """
        if not self.start_res_log2 <= res_log2 <= self.target_res_log2:
            # Out-of-range values would otherwise wrap around through negative
            # list indices or fail with an opaque IndexError.
            raise ValueError(
                f"res_log2 must be in [{self.start_res_log2}, "
                f"{self.target_res_log2}], got {res_log2}"
            )

        res = 2 ** res_log2
        num_stages = res_log2 - self.start_res_log2 + 1
        # The style input always carries one code per stage of the full model,
        # even when only the first ``num_stages`` of them are consumed.
        w = layers.Input(shape=(self.num_stages, 512), name="w")
        alpha = layers.Input(shape=(1), name="g_alpha")
        x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])

        if num_stages == 1:
            rgb = self.to_rgb[0](x)
        else:
            for i in range(1, num_stages - 1):
                x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            # Image of the previous stage, upsampled to the new resolution.
            old_rgb = self.to_rgb[num_stages - 2](x)
            old_rgb = layers.UpSampling2D((2, 2))(old_rgb)

            i = num_stages - 1
            x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
            new_rgb = self.to_rgb[i](x)

            rgb = fade_in(alpha[0], new_rgb, old_rgb)

        return keras.Model(
            [self.g_input, w, self.noise_inputs, alpha],
            rgb,
            name=f"generator_{res}_x_{res}",
        )
class Discriminator:
    """Builder that assembles class-conditional discriminator models.

    One discriminator block and one ``from_rgb`` stem are created per
    resolution from ``2 ** start_res_log2`` to ``2 ** target_res_log2``;
    ``grow`` wires them into a ``keras.Model`` for a given resolution.

    Args:
        start_res_log2: log2 of the lowest resolution.
        target_res_log2: log2 of the highest resolution.
    """

    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,  # 256x256
            9: 32,  # 512x512
            10: 16,  # 1024x1024
        }
        # list of discriminator blocks at increasing resolution
        self.d_blocks = []
        # list of layers to convert RGB into activation for d_blocks inputs
        self.from_rgb = []

        for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
            res = 2 ** res_log2
            filter_num = self.filter_nums[res_log2]
            from_rgb = Sequential(
                [
                    layers.InputLayer(
                        # 7 input channels; keep in sync with the Generator's
                        # to_rgb output channel count.
                        input_shape=(res, res, 7),
                        name=f"from_rgb_input_{res}",
                    ),
                    EqualizedConv(filter_num, 1),
                    layers.LeakyReLU(0.2),
                ],
                name=f"from_rgb_{res}",
            )
            self.from_rgb.append(from_rgb)

            if len(self.d_blocks) == 0:
                d_block = self.build_base(filter_num, res)
            else:
                d_block = self.build_block(
                    filter_num, self.filter_nums[res_log2 - 1], res
                )
            self.d_blocks.append(d_block)

    def build_base(self, filter_num, res):
        """Build the lowest-resolution block, which ends in a single score.

        Applies the minibatch standard-deviation channel, a 3x3 convolution,
        and two dense layers (``filter_num`` units, then 1 unit) with
        LeakyReLU activations in between.
        """
        input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
        x = minibatch_std(input_tensor)
        x = EqualizedConv(filter_num, 3)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Flatten()(x)
        x = EqualizedDense(filter_num)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedDense(1)(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def build_block(self, filter_num_1, filter_num_2, res):
        """Build an intermediate block that halves the spatial resolution.

        Two convolutions (``filter_num_1`` then ``filter_num_2`` channels),
        each followed by LeakyReLU, and a final 2x2 average pooling.
        """
        input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
        x = EqualizedConv(filter_num_1, 3)(input_tensor)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedConv(filter_num_2)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.AveragePooling2D((2, 2))(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def grow(self, res_log2):
        """Assemble the discriminator for resolution ``2 ** res_log2``.

        The image goes through the matching ``from_rgb`` stem, an AdaIN
        conditioned on the class embedding, and the newest block. For stages
        above the first, that result is blended (``fade_in`` with ``alpha``)
        with the ``from_rgb`` features of the 2x-downsampled image and then
        passed through the remaining lower-resolution blocks, each preceded
        by an AdaIN.

        NOTE(review): a fresh ``AdaIN`` layer is instantiated on every call,
        so these conditioning weights are not shared between the models built
        for different resolutions, unlike ``from_rgb`` and ``d_blocks`` -
        confirm this is intended.

        Args:
            res_log2: log2 of the resolution to build; must lie within
                ``[start_res_log2, target_res_log2]``.

        Returns:
            A ``keras.Model`` with inputs
            ``[input_image, class_embedding, alpha]``.

        Raises:
            ValueError: if ``res_log2`` is outside the configured range.
        """
        res = 2 ** res_log2
        idx = res_log2 - self.start_res_log2
        if not 0 <= idx < len(self.d_blocks):
            # A negative idx would silently index from the end of the lists.
            raise ValueError(
                f"res_log2 must be in [{self.start_res_log2}, "
                f"{self.target_res_log2}], got {res_log2}"
            )

        alpha = layers.Input(shape=(1), name="d_alpha")
        # 7 input channels; keep in sync with from_rgb above.
        input_image = layers.Input(shape=(res, res, 7), name="input_image")
        class_embedding = layers.Input(shape=512, name="class_embedding")

        x = self.from_rgb[idx](input_image)
        x = AdaIN()([x, class_embedding])
        x = self.d_blocks[idx](x)

        if idx > 0:
            idx -= 1
            downsized_image = layers.AveragePooling2D((2, 2))(input_image)
            y = self.from_rgb[idx](downsized_image)
            x = fade_in(alpha[0], x, y)

            for i in range(idx, -1, -1):
                x = AdaIN()([x, class_embedding])
                x = self.d_blocks[i](x)

        return keras.Model(
            [input_image, class_embedding, alpha],
            x,
            name=f"discriminator_{res}_x_{res}",
        )
class cStyleGAN(tf.keras.Model):
    """Class-conditional StyleGAN trained with progressive growing.

    Holds the mapping network, a 5-class label embedding, and the
    ``Generator`` / ``Discriminator`` builders. The actual ``generator`` and
    ``discriminator`` models are (re)built by ``grow_model`` for one
    resolution at a time.

    Args:
        z_dim: width of the latent vector sampled for the mapping network.
        target_res: highest image resolution (power of two).
        start_res: lowest image resolution (power of two).
    """

    def __init__(self, z_dim=512, target_res=64, start_res=4):
        super().__init__()
        self.z_dim = z_dim
        self.target_res_log2 = log2(target_res)
        self.start_res_log2 = log2(start_res)
        self.current_res_log2 = self.target_res_log2
        self.num_stages = self.target_res_log2 - self.start_res_log2 + 1
        # Fade-in coefficient fed to both networks; updated in train_step.
        self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")
        self.mapping = Mapping(num_stages=self.num_stages)
        self.embedding = layers.Embedding(5, 512)
        self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
        self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
        self.g_input_shape = self.g_builder.input_shape
        self.phase = None
        self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)
        self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}

    def grow_model(self, res):
        """Rebuild the generator and discriminator for resolution ``res``."""
        tf.keras.backend.clear_session()
        res_log2 = log2(res)
        self.generator = self.g_builder.grow(res_log2)
        self.discriminator = self.d_builder.grow(res_log2)
        self.current_res_log2 = res_log2
        print(f"\nModel resolution:{res}x{res}")

    def compile(
        self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
    ):
        """Configure the model for one training phase.

        Args:
            steps_per_epoch: number of steps over which alpha ramps up during
                a ``"TRANSITION"`` phase.
            phase: ``"TRANSITION"`` or ``"STABLE"``.
            res: image resolution to train at.
            d_optimizer: optimizer for the discriminator.
            g_optimizer: optimizer for the embedding, mapping and generator.
            *args, **kwargs: forwarded to ``keras.Model.compile``; an optional
                ``loss_weights`` keyword overrides the stored loss weights.
        """
        self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
        self.steps_per_epoch = steps_per_epoch

        # ``current_res_log2`` starts at the target resolution, so a first
        # compile at that resolution used to skip building the networks and
        # train_step then failed on the missing ``generator`` attribute.
        networks_missing = getattr(self, "generator", None) is None
        rebuilt = networks_missing or res != 2 ** self.current_res_log2
        if rebuilt:
            self.grow_model(res)

        # Optimizers are replaced only when the networks were rebuilt, or when
        # none have been set yet, so an unchanged resolution keeps its
        # existing optimizers.
        optimizers_missing = (
            getattr(self, "d_optimizer", None) is None
            or getattr(self, "g_optimizer", None) is None
        )
        if rebuilt or optimizers_missing:
            self.d_optimizer = d_optimizer
            self.g_optimizer = g_optimizer

        self.train_step_counter.assign(0)
        self.phase = phase
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")
        super().compile(*args, **kwargs)

    @property
    def metrics(self):
        """Metrics tracked by ``train_step`` (created in ``compile``)."""
        return [self.d_loss_metric, self.g_loss_metric]

    def generate_noise(self, batch_size):
        """Return one unit-normal noise map per resolution, lowest first."""
        noise = [
            tf.random.normal((batch_size, 2 ** res, 2 ** res, 1))
            for res in range(self.start_res_log2, self.target_res_log2 + 1)
        ]
        return noise

    def gradient_loss(self, grad):
        """Return the mean squared deviation of per-sample gradient norms from 1.

        Args:
            grad: gradient tensor whose first axis is the batch, or a
                single-element list/tuple wrapping such a tensor (the form
                returned by ``tape.gradient(target, [source])``).

        Returns:
            A scalar tensor.
        """
        # Unwrap a one-element list. Passing the list straight to tf.square
        # stacks it into a tensor with a new leading axis of size 1, which
        # made the reduction below sum over the batch axis and yield a single
        # norm for the whole batch instead of one per sample.
        if isinstance(grad, (list, tuple)) and len(grad) == 1:
            grad = grad[0]
        loss = tf.square(grad)
        loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
        loss = tf.sqrt(loss)
        loss = tf.reduce_mean(tf.square(loss - 1))
        return loss

    def train_step(self, data_tuple):
        """Run one generator update followed by one discriminator update.

        Args:
            data_tuple: ``(real_images, class_label)`` batch.

        Returns:
            Dict with the running means of ``d_loss`` and ``g_loss``.

        Raises:
            NotImplementedError: if ``phase`` is neither ``"TRANSITION"`` nor
                ``"STABLE"``.
        """
        real_images, class_label = data_tuple

        self.train_step_counter.assign_add(1)
        if self.phase == "TRANSITION":
            # Ramp alpha with the fraction of the epoch completed.
            self.alpha.assign(
                tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
            )
        elif self.phase == "STABLE":
            self.alpha.assign(1.0)
        else:
            raise NotImplementedError
        alpha = tf.expand_dims(self.alpha, 0)

        batch_size = tf.shape(real_images)[0]
        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)
        z = tf.random.normal((batch_size, self.z_dim))
        const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
        noise = self.generate_noise(batch_size)

        # generator
        with tf.GradientTape() as g_tape:
            class_embedding = self.embedding(class_label)
            w = self.mapping([z, class_embedding])
            fake_images = self.generator([const_input, w, noise, alpha])
            pred_fake = self.discriminator([fake_images, class_embedding, alpha])
            g_loss = wasserstein_loss(real_labels, pred_fake)

        trainable_weights = (
            self.embedding.trainable_weights
            + self.mapping.trainable_weights
            + self.generator.trainable_weights
        )
        gradients = g_tape.gradient(g_loss, trainable_weights)
        self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))

        # discriminator; reuses fake_images and class_embedding from the
        # generator step above.
        with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
            # forward pass
            pred_fake = self.discriminator([fake_images, class_embedding, alpha])
            pred_real = self.discriminator([real_images, class_embedding, alpha])

            epsilon = tf.random.uniform((batch_size, 1, 1, 1))
            interpolates = epsilon * real_images + (1 - epsilon) * fake_images
            gradient_tape.watch(interpolates)
            pred_interpolates = self.discriminator(
                [interpolates, class_embedding, alpha]
            )

            # calculate losses
            loss_fake = wasserstein_loss(fake_labels, pred_fake)
            loss_real = wasserstein_loss(real_labels, pred_real)

            # gradient penalty: differentiate the raw critic outputs (a
            # non-scalar target is summed by the tape) rather than a batch
            # mean, which would scale every gradient by 1 / batch_size. This
            # stays inside the outer tape so the penalty is differentiable.
            gradients_interpolates = gradient_tape.gradient(
                pred_interpolates, interpolates
            )
            gradient_penalty = self.loss_weights[
                "gradient_penalty"
            ] * self.gradient_loss(gradients_interpolates)

            # drift loss
            all_pred = tf.concat([pred_fake, pred_real], axis=0)
            drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred ** 2)

            d_loss = loss_fake + loss_real + gradient_penalty + drift_loss

        gradients = total_tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_weights)
        )

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

    def call(self, inputs: dict):
        """Generate images from a dict of optional inputs.

        Recognized keys (all optional):
            ``style_code``: precomputed style codes; skips the mapping network.
            ``z``: latent vectors; sampled from a unit normal when absent.
            ``noise``: list of per-resolution noise maps; sampled when absent.
            ``class_label``: labels for the embedding (default ``0``).
            ``batch_size``: number of images to generate (default ``1``).
            ``alpha``: fade-in coefficient (default ``1.0``).

        NOTE(review): with the scalar default ``class_label`` the embedding
        has no batch axis - confirm the mapping network accepts it, or pass
        labels shaped ``(batch_size,)``.

        Returns:
            Generator output mapped via ``(x * 0.5 + 0.5) * 255`` and clipped
            to ``[0, 255]``.
        """
        style_code = inputs.get("style_code", None)
        z = inputs.get("z", None)
        noise = inputs.get("noise", None)
        class_label = inputs.get("class_label", 0)
        batch_size = inputs.get("batch_size", 1)
        alpha = inputs.get("alpha", 1.0)
        alpha = tf.expand_dims(alpha, 0)
        class_embedding = self.embedding(class_label)

        if style_code is None:
            if z is None:
                z = tf.random.normal((batch_size, self.z_dim))
            style_code = self.mapping([z, class_embedding])

        if noise is None:
            noise = self.generate_noise(batch_size)

        const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
        images = self.generator([const_input, style_code, noise, alpha])
        images = tf.clip_by_value((images * 0.5 + 0.5) * 255, 0, 255)
        return images