diff --git "a/networks/layers.py" "b/networks/layers.py" --- "a/networks/layers.py" +++ "b/networks/layers.py" @@ -1,17 +1,5 @@ import tensorflow as tf -import numpy as np -import scipy.signal as sps -import scipy.special as spspec - -import tensorflow.keras.backend as K -import math -from tensorflow.python.keras.utils import conv_utils -from tensorflow.keras.layers import Layer, InputSpec, Conv2D, LeakyReLU, Dense, BatchNormalization, Input, Concatenate -from tensorflow.keras.layers import Conv2DTranspose, ReLU, Activation, UpSampling2D, Add, Reshape, Multiply -from tensorflow.keras.layers import AveragePooling2D, LayerNormalization, GlobalAveragePooling2D, MaxPooling2D, Flatten -from tensorflow.keras import initializers, constraints, regularizers -from tensorflow.keras.models import Model -from tensorflow_addons.layers import InstanceNormalization +from tensorflow.keras.layers import Layer, Dense def sin_activation(x, omega=30): @@ -47,548 +35,6 @@ class AdaIN(Layer): return dict(list(base_config.items()) + list(config.items())) -class Conv2DMod(Layer): - - def __init__(self, - filters, - kernel_size, - strides=1, - padding='valid', - dilation_rate=1, - kernel_initializer='glorot_uniform', - kernel_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - demod=True, - **kwargs): - super(Conv2DMod, self).__init__(**kwargs) - self.filters = filters - self.rank = 2 - self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size') - self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') - self.padding = conv_utils.normalize_padding(padding) - self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate') - self.kernel_initializer = initializers.get(kernel_initializer) - self.kernel_regularizer = regularizers.get(kernel_regularizer) - self.activity_regularizer = regularizers.get(activity_regularizer) - self.kernel_constraint = constraints.get(kernel_constraint) - self.demod = demod - self.input_spec = [InputSpec(ndim = 4), - InputSpec(ndim = 2)] - - def build(self, input_shape): - channel_axis = -1 - if input_shape[0][channel_axis] is None: - raise ValueError('The channel dimension of the inputs ' - 'should be defined. Found `None`.') - input_dim = input_shape[0][channel_axis] - kernel_shape = self.kernel_size + (input_dim, self.filters) - - if input_shape[1][-1] != input_dim: - raise ValueError('The last dimension of modulation input should be equal to input dimension.') - - self.kernel = self.add_weight(shape=kernel_shape, - initializer=self.kernel_initializer, - name='kernel', - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) - - # Set input spec. - self.input_spec = [InputSpec(ndim=4, axes={channel_axis: input_dim}), - InputSpec(ndim=2)] - self.built = True - - def call(self, inputs): - - #To channels last - x = tf.transpose(inputs[0], [0, 3, 1, 2]) - - #Get weight and bias modulations - #Make sure w's shape is compatible with self.kernel - w = K.expand_dims(K.expand_dims(K.expand_dims(inputs[1], axis = 1), axis = 1), axis = -1) - - #Add minibatch layer to weights - wo = K.expand_dims(self.kernel, axis = 0) - - #Modulate - weights = wo * (w+1) - - #Demodulate - if self.demod: - d = K.sqrt(K.sum(K.square(weights), axis=[1,2,3], keepdims = True) + 1e-8) - weights = weights / d - - #Reshape/scale input - x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups. 
-        w = tf.reshape(tf.transpose(weights, [1, 2, 3, 0, 4]), [weights.shape[1], weights.shape[2], weights.shape[3], -1])
-
-        x = tf.nn.conv2d(x, w,
-                         strides=self.strides,
-                         padding="SAME",
-                         data_format="NCHW")
-
-        # Reshape/scale output.
-        x = tf.reshape(x, [-1, self.filters, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.
-        x = tf.transpose(x, [0, 2, 3, 1])
-
-        return x
-
-    def compute_output_shape(self, input_shape):
-        space = input_shape[0][1:-1]
-        new_space = []
-        for i in range(len(space)):
-            new_dim = conv_utils.conv_output_length(
-                space[i],
-                self.kernel_size[i],
-                padding=self.padding,
-                stride=self.strides[i],
-                dilation=self.dilation_rate[i])
-            new_space.append(new_dim)
-
-        return (input_shape[0],) + tuple(new_space) + (self.filters,)
-
-    def get_config(self):
-        config = {
-            'filters': self.filters,
-            'kernel_size': self.kernel_size,
-            'strides': self.strides,
-            'padding': self.padding,
-            'dilation_rate': self.dilation_rate,
-            'kernel_initializer': initializers.serialize(self.kernel_initializer),
-            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-            'activity_regularizer':
-                regularizers.serialize(self.activity_regularizer),
-            'kernel_constraint': constraints.serialize(self.kernel_constraint),
-            'demod': self.demod
-        }
-        base_config = super(Conv2DMod, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-
-class CreatePatches(tf.keras.layers.Layer ):
-
-    def __init__( self , patch_size):
-        super( CreatePatches , self).__init__()
-        self.patch_size = patch_size
-
-    def call(self, inputs):
-        patches = []
-        # For square images only ( as inputs.shape[ 1 ] = inputs.shape[ 2 ] )
-        input_image_size = inputs.shape[ 1 ]
-        for i in range( 0 , input_image_size , self.patch_size ):
-            for j in range( 0 , input_image_size , self.patch_size ):
-                patches.append( inputs[ : , i : i + self.patch_size , j : j + self.patch_size , : ] )
-        return patches
-
-    def get_config(self):
-        config = {'patch_size': self.patch_size,
-                  }
-        base_config = super(CreatePatches, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-
-class SelfAttention(tf.keras.layers.Layer ):
-
-    def __init__( self , alpha, filters=128):
-        super(SelfAttention , self).__init__()
-        self.alpha = alpha
-        self.filters = filters
-
-        self.f = Conv2D(filters, 1, 1)
-        self.g = Conv2D(filters, 1, 1)
-        self.s = Conv2D(filters, 1, 1)
-
-    def call(self, inputs):
-
-        f_map = self.f(inputs)
-        f_map = tf.image.transpose(f_map)
-
-        g_map = self.g(inputs)
-
-        s_map = self.s(inputs)
-
-        att = f_map * g_map
-
-        att = att / self.alpha
-
-        return tf.keras.activations.softmax(att + s_map, axis=0)
-
-    def get_config(self):
-        config = {'alpha': self.alpha,
-                  }
-        base_config = super(SelfAttention, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-
-class Sampling(Layer):
-    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
-
-    def call(self, inputs):
-        z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
-
-    def get_config(self):
-        config = {
-        }
-        base_config = super(Sampling, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
-
-
-class ArcMarginPenaltyLogists(tf.keras.layers.Layer):
-    """ArcMarginPenaltyLogists"""
-    def __init__(self, num_classes, margin=0.7, logist_scale=64, **kwargs):
-
super(ArcMarginPenaltyLogists, self).__init__(**kwargs) - self.num_classes = num_classes - self.margin = margin - self.logist_scale = logist_scale - - def build(self, input_shape): - self.w = self.add_variable( - "weights", shape=[int(input_shape[-1]), self.num_classes]) - self.cos_m = tf.identity(math.cos(self.margin), name='cos_m') - self.sin_m = tf.identity(math.sin(self.margin), name='sin_m') - self.th = tf.identity(math.cos(math.pi - self.margin), name='th') - self.mm = tf.multiply(self.sin_m, self.margin, name='mm') - - def call(self, embds, labels): - normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd') - normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights') - - cos_t = tf.matmul(normed_embds, normed_w, name='cos_t') - sin_t = tf.sqrt(1. - cos_t ** 2, name='sin_t') - - cos_mt = tf.subtract( - cos_t * self.cos_m, sin_t * self.sin_m, name='cos_mt') - - cos_mt = tf.where(cos_t > self.th, cos_mt, cos_t - self.mm) - - mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes, - name='one_hot_mask') - - logists = tf.where(mask == 1., cos_mt, cos_t) - logists = tf.multiply(logists, self.logist_scale, 'arcface_logist') - - return logists - - def get_config(self): - config = {'num_classes': self.num_classes, - 'margin': self.margin, - 'logist_scale': self.logist_scale - } - base_config = super(ArcMarginPenaltyLogists, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - -class KLLossLayer(tf.keras.layers.Layer): - """ArcMarginPenaltyLogists""" - def __init__(self, beta=1.5, **kwargs): - super(KLLossLayer, self).__init__(**kwargs) - self.beta = beta - - def call(self, inputs): - z_mean, z_log_var = inputs - - kl_loss = tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) - kl_loss = -0.5 * kl_loss * self.beta - - self.add_loss(kl_loss * 0) - self.add_metric(kl_loss, 'kl_loss') - - return inputs - - def get_config(self): - config = { - 'beta': self.beta - } - base_config = super(KLLossLayer, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class ReflectionPadding2D(Layer): - def __init__(self, padding=(1, 1), **kwargs): - self.padding = tuple(padding) - self.input_spec = [InputSpec(ndim=4)] - super(ReflectionPadding2D, self).__init__(**kwargs) - - def compute_output_shape(self, s): - return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3]) - - def call(self, x, mask=None): - w_pad,h_pad = self.padding - return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT') - - def get_config(self): - config = { - 'padding': self.padding, - } - base_config = super(ReflectionPadding2D, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class ResBlock(Layer): - - def __init__(self, fil, **kwargs): - super(ResBlock, self).__init__(**kwargs) - self.fil = fil - - self.conv_0 = Conv2D(kernel_size=3, filters=fil, strides=1) - self.conv_1 = Conv2D(kernel_size=3, filters=fil, strides=1) - - self.res = Conv2D(kernel_size=1, filters=1, strides=1) - - self.lrelu = LeakyReLU(0.2) - self.padding = ReflectionPadding2D(padding=(1, 1)) - - def call(self, inputs): - res = self.res(inputs) - - x = self.padding(inputs) - x = self.conv_0(x) - x = self.lrelu(x) - - x = self.padding(x) - x = self.conv_1(x) - x = self.lrelu(x) - - out = x + res - - return out - - def get_config(self): - config = { - 'fil': self.fil - } - base_config = super(ResBlock, self).get_config() - return dict(list(base_config.items()) + list(config.items())) 
- - -class SubpixelConv2D(Layer): - """ Subpixel Conv2D Layer - upsampling a layer from (h, w, c) to (h*r, w*r, c/(r*r)), - where r is the scaling factor, default to 4 - # Arguments - upsampling_factor: the scaling factor - # Input shape - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - # Output shape - the second and the third dimension increased by a factor of - `upsampling_factor`; the last layer decreased by a factor of - `upsampling_factor^2`. - # References - Real-Time Single Image and Video Super-Resolution Using an Efficient - Sub-Pixel Convolutional Neural Network Shi et Al. https://arxiv.org/abs/1609.05158 - """ - - def __init__(self, upsampling_factor=4, **kwargs): - super(SubpixelConv2D, self).__init__(**kwargs) - self.upsampling_factor = upsampling_factor - - def build(self, input_shape): - last_dim = input_shape[-1] - factor = self.upsampling_factor * self.upsampling_factor - if last_dim % (factor) != 0: - raise ValueError('Channel ' + str(last_dim) + ' should be of ' - 'integer times of upsampling_factor^2: ' + - str(factor) + '.') - - def call(self, inputs, **kwargs): - return tf.nn.depth_to_space( inputs, self.upsampling_factor ) - - def get_config(self): - config = { 'upsampling_factor': self.upsampling_factor, } - base_config = super(SubpixelConv2D, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - def compute_output_shape(self, input_shape): - factor = self.upsampling_factor * self.upsampling_factor - input_shape_1 = None - if input_shape[1] is not None: - input_shape_1 = input_shape[1] * self.upsampling_factor - input_shape_2 = None - if input_shape[2] is not None: - input_shape_2 = input_shape[2] * self.upsampling_factor - dims = [ input_shape[0], - input_shape_1, - input_shape_2, - int(input_shape[3]/factor) - ] - return tuple( dims ) - - -def id_mod_res(inputs, c): - feature_map, z_id = inputs - - x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(feature_map) - - x = AdaIN()([x, z_id]) - - x = ReLU()(x) - - x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) - - x = AdaIN()([x, z_id]) - - out = Add()([x, feature_map]) - - return out - - -def id_mod_res_v2(inputs, c): - feature_map, z_id = inputs - - affine = Dense(feature_map.shape[-1])(z_id) - x = Conv2DMod(c, kernel_size=3, padding='same', - kernel_regularizer=tf.keras.regularizers.l2(0.0001))([feature_map, affine]) - - x = ReLU()(x) - - affine = Dense(x.shape[-1])(z_id) - x = Conv2DMod(c, kernel_size=3, padding='same', - kernel_regularizer=tf.keras.regularizers.l2(0.0001))([x, affine]) - out = Add()([x, feature_map]) - - x = ReLU()(x) - - return out - - -def simswap(im_size, filter_scale=1, deep=True): - inputs = Input(shape=(im_size, im_size, 3)) - z_id = Input(shape=(512,)) - - x = ReflectionPadding2D(padding=(3, 3))(inputs) - x = Conv2D(filters=64 // filter_scale, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 112 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = Conv2D(filters=64 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 56 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = Conv2D(filters=256 // filter_scale, kernel_size=3, strides=2, padding='same', 
kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 28 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 14 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) # 14 - - if deep: - x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x) # 7 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = id_mod_res([x, z_id], 512 // filter_scale) - - x = id_mod_res([x, z_id], 512 // filter_scale) - - x = id_mod_res([x, z_id], 512 // filter_scale) - - x = id_mod_res([x, z_id], 512 // filter_scale) - - if deep: - x = SubpixelConv2D(upsampling_factor=2)(x) - x = Conv2D(filters=512 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = SubpixelConv2D(upsampling_factor=2)(x) - x = Conv2D(filters=256 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = SubpixelConv2D(upsampling_factor=2)(x) - x = Conv2D(filters=128 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) # 56 - - x = SubpixelConv2D(upsampling_factor=2)(x) - x = Conv2D(filters=64 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) # 112 - - x = ReflectionPadding2D(padding=(3, 3))(x) - out = Conv2D(filters=3, kernel_size=7, padding='valid')(x) # 112 - - model = Model([inputs, z_id], out) - model.summary() - - return model - - -def simswap_v2(deep=True): - inputs = Input(shape=(224, 224, 3)) - z_id = Input(shape=(512,)) - - x = ReflectionPadding2D(padding=(3, 3))(inputs) - x = Conv2D(filters=64, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 112 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 56 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = Conv2D(filters=256, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 28 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 14 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) # 14 - - if deep: - x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) # 7 - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = id_mod_res_v2([x, z_id], 512) - - x = id_mod_res_v2([x, z_id], 512) - - x = id_mod_res_v2([x, z_id], 512) - - x = id_mod_res_v2([x, z_id], 512) - - x = id_mod_res_v2([x, z_id], 512) - - x = id_mod_res_v2([x, z_id], 512) - - if deep: - x = UpSampling2D(interpolation='bilinear')(x) - x = Conv2D(filters=512, kernel_size=3, 
padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = UpSampling2D(interpolation='bilinear')(x) - x = Conv2D(filters=256, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) - - x = UpSampling2D(interpolation='bilinear')(x) - x = Conv2D(filters=128, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) # 56 - - x = UpSampling2D(interpolation='bilinear')(x) - x = Conv2D(filters=64, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x) - x = BatchNormalization()(x) - x = Activation(tf.keras.activations.relu)(x) # 112 - - x = ReflectionPadding2D(padding=(3, 3))(x) - out = Conv2D(filters=3, kernel_size=7, padding='valid')(x) # 112 - out = Activation('sigmoid')(out) - - model = Model([inputs, z_id], out) - model.summary() - - return model - - class AdaptiveAttention(Layer): def __init__(self, **kwargs): @@ -601,1973 +47,3 @@ class AdaptiveAttention(Layer): def get_config(self): base_config = super(AdaptiveAttention, self).get_config() return base_config - - -def aad_block(inputs, c_out): - h, z_att, z_id = inputs - - h_norm = BatchNormalization()(h) - h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(h_norm) - - m = Activation('sigmoid')(h) - - z_att_gamma = Conv2D(filters=c_out, - kernel_size=1, - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att) - - z_att_beta = Conv2D(filters=c_out, - kernel_size=1, - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att) - - a = Multiply()([h_norm, z_att_gamma]) - a = Add()([a, z_att_beta]) - - z_id_gamma = Dense(h_norm.shape[-1], - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id) - z_id_gamma = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_gamma) - - z_id_beta = Dense(h_norm.shape[-1], - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id) - z_id_beta = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_beta) - - i = Multiply()([h_norm, z_id_gamma]) - i = Add()([i, z_id_beta]) - - h_out = AdaptiveAttention()([m, a, i]) - - return h_out - - -def aad_block_mod(inputs, c_out): - h, z_att, z_id = inputs - - h_norm = BatchNormalization()(h) - h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(h_norm) - - m = Activation('sigmoid')(h) - - z_att_gamma = Conv2D(filters=c_out, - kernel_size=1, - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att) - - z_att_beta = Conv2D(filters=c_out, - kernel_size=1, - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att) - - a = Multiply()([h_norm, z_att_gamma]) - a = Add()([a, z_att_beta]) - - z_id_gamma = Dense(h_norm.shape[-1], - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_id) - - i = Conv2DMod(filters=c_out, - kernel_size=1, - padding='same', - kernel_initializer='he_uniform', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))([h_norm, z_id_gamma]) - - h_out = AdaptiveAttention()([m, a, i]) - - return h_out - - -def aad_res_block(inputs, c_in, c_out): - h, z_att, z_id = inputs - - if c_in == c_out: - aad = aad_block([h, z_att, z_id], c_out) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) - - aad = 
aad_block([conv, z_att, z_id], c_out) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) - - h_out = Add()([h, conv]) - return h_out - else: - aad = aad_block([h, z_att, z_id], c_in) - act = ReLU()(aad) - h_res = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) - - aad = aad_block([h, z_att, z_id], c_in) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) - - aad = aad_block([conv, z_att, z_id], c_out) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act) - - h_out = Add()([h_res, conv]) - - return h_out - - -def aad_res_block_mod(inputs, c_in, c_out): - h, z_att, z_id = inputs - - if c_in == c_out: - aad = aad_block_mod([h, z_att, z_id], c_out) - - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) - - aad = aad_block_mod([conv, z_att, z_id], c_out) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) - - h_out = Add()([h, conv]) - return h_out - else: - aad = aad_block_mod([h, z_att, z_id], c_in) - act = ReLU()(aad) - h_res = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) - - aad = aad_block_mod([h, z_att, z_id], c_in) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) - - aad = aad_block_mod([conv, z_att, z_id], c_out) - act = ReLU()(aad) - conv = Conv2D(filters=c_out, - kernel_size=3, - padding='same', - kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act) - - h_out = Add()([h_res, conv]) - - return h_out - - -class FilteredReLU(Layer): - - def __init__(self, - critically_sampled, - - in_channels, - out_channels, - in_size, - out_size, - in_sampling_rate, - out_sampling_rate, - in_cutoff, - out_cutoff, - in_half_width, - out_half_width, - - conv_kernel = 3, - lrelu_upsampling = 2, - filter_size = 6, - conv_clamp = 256, - use_radial_filters = False, - is_torgb = False, - **kwargs): - super(FilteredReLU, self).__init__(**kwargs) - self.critically_sampled = critically_sampled - - self.in_channels = in_channels - self.out_channels = out_channels - self.in_size = np.broadcast_to(np.asarray(in_size), [2]) - self.out_size = np.broadcast_to(np.asarray(out_size), [2]) - self.in_sampling_rate = in_sampling_rate - self.out_sampling_rate = out_sampling_rate - self.in_cutoff = in_cutoff - self.out_cutoff = out_cutoff - self.in_half_width = in_half_width - self.out_half_width = out_half_width - - self.is_torgb = is_torgb - - self.conv_kernel = 1 if is_torgb else conv_kernel - self.lrelu_upsampling = lrelu_upsampling - self.conv_clamp = conv_clamp - - self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) - - # Up sampling filter - self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) - assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate - self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 - self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, 
- cutoff=self.in_cutoff, - width=self.in_half_width*2, - fs=self.tmp_sampling_rate) - - # Down sampling filter - self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) - assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate - self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 - self.d_radial = use_radial_filters and not self.critically_sampled - self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, - cutoff=self.out_cutoff, - width=self.out_half_width*2, - fs=self.tmp_sampling_rate, - radial=self.d_radial) - # Compute padding - pad_total = (self.out_size - 1) * self.d_factor + 1 - pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor - pad_total += self.u_taps + self.d_taps - 2 - pad_lo = (pad_total + self.u_factor) // 2 - pad_hi = pad_total - pad_lo - self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] - - self.gain = 1 if self.is_torgb else np.sqrt(2) - self.slope = 1 if self.is_torgb else 0.2 - - self.act_funcs = {'linear': - {'func': lambda x, **_: x, - 'def_alpha': 0, - 'def_gain': 1}, - 'lrelu': - {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), - 'def_alpha': 0.2, - 'def_gain': np.sqrt(2)}, - } - - b_init = tf.zeros_initializer() - self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), - dtype="float32"), - trainable=True) - - def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): - if numtaps == 1: - return None - - if not radial: - f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) - return f - - x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs - r = np.hypot(*np.meshgrid(x, x)) - f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) - beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) - w = np.kaiser(numtaps, beta) - f *= np.outer(w, w) - f /= np.sum(f) - return f - - def get_filter_size(self, f): - if f is None: - return 1, 1 - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - - def parse_padding(self, padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - - def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): - spec = self.act_funcs[act] - alpha = float(alpha if alpha is not None else spec['def_alpha']) - gain = float(gain if gain is not None else spec['def_gain']) - clamp = float(clamp if clamp is not None else -1) - - if b is not None: - x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) - x = spec['func'](x, alpha=alpha) - - if gain != 1: - x = x * gain - - if clamp >= 0: - x = tf.clip_by_value(x, -clamp, clamp) - return x - - def parse_scaling(self, scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - - def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - if f is None: - f = tf.ones([1, 1], dtype=tf.float32) - - batch_size, in_height, in_width, num_channels = x.shape - - upx, upy = self.parse_scaling(up) - downx, downy = self.parse_scaling(down) - padx0, padx1, pady0, pady1 = self.parse_padding(padding) - - upW = in_width * upx + padx0 + padx1 - upH = in_height * upy + pady0 + 
pady1 - assert upW >= f.shape[-1] and upH >= f.shape[0] - - # Channel first format. - x = tf.transpose(x, perm=[0, 3, 1, 2]) - - # Upsample by inserting zeros. - x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) - x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) - x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = tf.pad(x, [[0, 0], [0, 0], - [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], - [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) - x = x[:, :, - tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), - tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = tf.cast(f, dtype=x.dtype) - if not flip_filter: - f = tf.reverse(f, axis=[-1]) - f = tf.reshape(f, shape=(1, 1, f.shape[-1])) - f = tf.repeat(f, repeats=num_channels, axis=0) - - if tf.rank(f) == 4: - f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) - x = tf.nn.conv2d(x, f_0, 1, 'VALID') - else: - f_0 = tf.expand_dims(f, axis=2) - f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) - - f_1 = tf.expand_dims(f, axis=3) - f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) - - x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') - - x = x[:, :, ::downy, ::downx] - - # Back to channel last. - x = tf.transpose(x, perm=[0, 2, 3, 1]) - return x - - - def filtered_lrelu(self, - x, fu=None, fd=None, b=None, - up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - #fu_w, fu_h = self.get_filter_size(fu) - #fd_w, fd_h = self.get_filter_size(fd) - - px0, px1, py0, py1 = self.parse_padding(padding) - - #batch_size, in_h, in_w, channels = x.shape - #in_dtype = x.dtype - #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down - #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down - - x = self.bias_act(x=x, b=b) - x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) - x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) - x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) - - return x - - - def call(self, inputs): - return self.filtered_lrelu(inputs, - fu=self.u_filter, - fd=self.d_filter, - b=self.bias, - up=self.u_factor, - down=self.d_factor, - padding=self.padding, - gain=self.gain, - slope=self.slope, - clamp=self.conv_clamp) - - def get_config(self): - base_config = super(FilteredReLU, self).get_config() - return base_config - - -class SynthesisLayer(Layer): - - def __init__(self, - critically_sampled, - - in_channels, - out_channels, - in_size, - out_size, - in_sampling_rate, - out_sampling_rate, - in_cutoff, - out_cutoff, - in_half_width, - out_half_width, - - conv_kernel = 3, - lrelu_upsampling = 2, - filter_size = 6, - conv_clamp = 256, - use_radial_filters = False, - is_torgb = False, - **kwargs): - super(SynthesisLayer, self).__init__(**kwargs) - self.critically_sampled = critically_sampled - - self.in_channels = in_channels - self.out_channels = out_channels - self.in_size = np.broadcast_to(np.asarray(in_size), [2]) - self.out_size = np.broadcast_to(np.asarray(out_size), [2]) - self.in_sampling_rate = in_sampling_rate - self.out_sampling_rate = out_sampling_rate - self.in_cutoff = in_cutoff - self.out_cutoff = out_cutoff - self.in_half_width = in_half_width - self.out_half_width = out_half_width - - self.is_torgb = is_torgb 
- - self.conv_kernel = 1 if is_torgb else conv_kernel - self.lrelu_upsampling = lrelu_upsampling - self.conv_clamp = conv_clamp - - self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) - - # Up sampling filter - self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) - assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate - self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 - self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, - cutoff=self.in_cutoff, - width=self.in_half_width*2, - fs=self.tmp_sampling_rate) - - # Down sampling filter - self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) - assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate - self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 - self.d_radial = use_radial_filters and not self.critically_sampled - self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, - cutoff=self.out_cutoff, - width=self.out_half_width*2, - fs=self.tmp_sampling_rate, - radial=self.d_radial) - # Compute padding - pad_total = (self.out_size - 1) * self.d_factor + 1 - pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor - pad_total += self.u_taps + self.d_taps - 2 - pad_lo = (pad_total + self.u_factor) // 2 - pad_hi = pad_total - pad_lo - self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] - - self.gain = 1 if self.is_torgb else np.sqrt(2) - self.slope = 1 if self.is_torgb else 0.2 - - self.act_funcs = {'linear': - {'func': lambda x, **_: x, - 'def_alpha': 0, - 'def_gain': 1}, - 'lrelu': - {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), - 'def_alpha': 0.2, - 'def_gain': np.sqrt(2)}, - } - - b_init = tf.zeros_initializer() - self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), - dtype="float32"), - trainable=True) - self.affine = Dense(self.in_channels) - self.conv = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') - - def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): - if numtaps == 1: - return None - - if not radial: - f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) - return f - - x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs - r = np.hypot(*np.meshgrid(x, x)) - f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) - beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) - w = np.kaiser(numtaps, beta) - f *= np.outer(w, w) - f /= np.sum(f) - return f - - def get_filter_size(self, f): - if f is None: - return 1, 1 - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - - def parse_padding(self, padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - - def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): - spec = self.act_funcs[act] - alpha = float(alpha if alpha is not None else spec['def_alpha']) - gain = float(gain if gain is not None else spec['def_gain']) - clamp = float(clamp if clamp is not None else -1) - - if b is not None: - x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) - x = 
spec['func'](x, alpha=alpha) - - if gain != 1: - x = x * gain - - if clamp >= 0: - x = tf.clip_by_value(x, -clamp, clamp) - return x - - def parse_scaling(self, scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - - def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - if f is None: - f = tf.ones([1, 1], dtype=tf.float32) - - batch_size, in_height, in_width, num_channels = x.shape - - upx, upy = self.parse_scaling(up) - downx, downy = self.parse_scaling(down) - padx0, padx1, pady0, pady1 = self.parse_padding(padding) - - upW = in_width * upx + padx0 + padx1 - upH = in_height * upy + pady0 + pady1 - assert upW >= f.shape[-1] and upH >= f.shape[0] - - # Channel first format. - x = tf.transpose(x, perm=[0, 3, 1, 2]) - - # Upsample by inserting zeros. - x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) - x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) - x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = tf.pad(x, [[0, 0], [0, 0], - [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], - [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) - x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = tf.cast(f, dtype=x.dtype) - if not flip_filter: - f = tf.reverse(f, axis=[-1]) - f = tf.reshape(f, shape=(1, 1, f.shape[-1])) - f = tf.repeat(f, repeats=num_channels, axis=0) - - if tf.rank(f) == 4: - f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) - x = tf.nn.conv2d(x, f_0, 1, 'VALID') - else: - f_0 = tf.expand_dims(f, axis=2) - f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) - - f_1 = tf.expand_dims(f, axis=3) - f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) - - x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') - - x = x[:, :, ::downy, ::downx] - - # Back to channel last. 
- x = tf.transpose(x, perm=[0, 2, 3, 1]) - return x - - - def filtered_lrelu(self, - x, fu=None, fd=None, b=None, - up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - #fu_w, fu_h = self.get_filter_size(fu) - #fd_w, fd_h = self.get_filter_size(fd) - - px0, px1, py0, py1 = self.parse_padding(padding) - - #batch_size, in_h, in_w, channels = x.shape - #in_dtype = x.dtype - #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down - #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down - - x = self.bias_act(x=x, b=b) - x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) - x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) - x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) - - return x - - - def call(self, inputs): - x, w = inputs - styles = self.affine(w) - x = self.conv([x, styles]) - x = self.filtered_lrelu(x, - fu=self.u_filter, - fd=self.d_filter, - b=self.bias, - up=self.u_factor, - down=self.d_factor, - padding=self.padding, - gain=self.gain, - slope=self.slope, - clamp=self.conv_clamp) - return x - - def get_config(self): - base_config = super(SynthesisLayer, self).get_config() - return base_config - - -class SynthesisLayerNoMod(Layer): - - def __init__(self, - critically_sampled, - - in_channels, - out_channels, - in_size, - out_size, - in_sampling_rate, - out_sampling_rate, - in_cutoff, - out_cutoff, - in_half_width, - out_half_width, - - conv_kernel = 3, - lrelu_upsampling = 2, - filter_size = 6, - conv_clamp = 256, - use_radial_filters = False, - is_torgb = False, - batch_size = 10, - **kwargs): - super(SynthesisLayerNoMod, self).__init__(**kwargs) - self.critically_sampled = critically_sampled - self.bs = batch_size - - self.in_channels = in_channels - self.out_channels = out_channels - self.in_size = np.broadcast_to(np.asarray(in_size), [2]) - self.out_size = np.broadcast_to(np.asarray(out_size), [2]) - self.in_sampling_rate = in_sampling_rate - self.out_sampling_rate = out_sampling_rate - self.in_cutoff = in_cutoff - self.out_cutoff = out_cutoff - self.in_half_width = in_half_width - self.out_half_width = out_half_width - - self.is_torgb = is_torgb - - self.conv_kernel = 1 if is_torgb else conv_kernel - self.lrelu_upsampling = lrelu_upsampling - self.conv_clamp = conv_clamp - - self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) - - # Up sampling filter - self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) - assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate - self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 - self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, - cutoff=self.in_cutoff, - width=self.in_half_width*2, - fs=self.tmp_sampling_rate) - - # Down sampling filter - self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) - assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate - self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 - self.d_radial = use_radial_filters and not self.critically_sampled - self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, - cutoff=self.out_cutoff, - width=self.out_half_width*2, - fs=self.tmp_sampling_rate, - radial=self.d_radial) - # Compute padding - pad_total = (self.out_size - 1) * self.d_factor + 1 - pad_total -= (self.in_size + 
self.conv_kernel - 1) * self.u_factor - pad_total += self.u_taps + self.d_taps - 2 - pad_lo = (pad_total + self.u_factor) // 2 - pad_hi = pad_total - pad_lo - self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] - - self.gain = 1 if self.is_torgb else np.sqrt(2) - self.slope = 1 if self.is_torgb else 0.2 - - self.act_funcs = {'linear': - {'func': lambda x, **_: x, - 'def_alpha': 0, - 'def_gain': 1}, - 'lrelu': - {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), - 'def_alpha': 0.2, - 'def_gain': np.sqrt(2)}, - } - - b_init = tf.zeros_initializer() - self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), - dtype="float32"), - trainable=True) - self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') - - def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): - if numtaps == 1: - return None - - if not radial: - f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) - return f - - x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs - r = np.hypot(*np.meshgrid(x, x)) - f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) - beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) - w = np.kaiser(numtaps, beta) - f *= np.outer(w, w) - f /= np.sum(f) - return f - - def get_filter_size(self, f): - if f is None: - return 1, 1 - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - - def parse_padding(self, padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - - @tf.function - def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): - spec = self.act_funcs[act] - alpha = float(alpha if alpha is not None else spec['def_alpha']) - gain = float(gain if gain is not None else spec['def_gain']) - clamp = float(clamp if clamp is not None else -1) - - if b is not None: - x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) - x = spec['func'](x, alpha=alpha) - - if gain != 1: - x = x * gain - - if clamp >= 0: - x = tf.clip_by_value(x, -clamp, clamp) - return x - - def parse_scaling(self, scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - - @tf.function - def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - if f is None: - f = tf.ones([1, 1], dtype=tf.float32) - - batch_size, in_height, in_width, num_channels = x.shape - batch_size = tf.shape(x)[0] - - upx, upy = self.parse_scaling(up) - downx, downy = self.parse_scaling(down) - padx0, padx1, pady0, pady1 = self.parse_padding(padding) - - upW = in_width * upx + padx0 + padx1 - upH = in_height * upy + pady0 + pady1 - assert upW >= f.shape[-1] and upH >= f.shape[0] - - # Channel first format. - x = tf.transpose(x, perm=[0, 3, 1, 2]) - # Upsample by inserting zeros. - x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) - x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) - x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. 
- x = tf.pad(x, [[0, 0], [0, 0], - [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], - [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) - x = x[:, :, - tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), - tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (tf.rank(f) / 2)) - f = tf.cast(f, dtype=x.dtype) - if not flip_filter: - f = tf.reverse(f, axis=[-1]) - f = tf.reshape(f, shape=(1, 1, f.shape[-1])) - f = tf.repeat(f, repeats=num_channels, axis=0) - - #if tf.rank(f) == 500: - # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) - # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - #else: - f_0 = tf.expand_dims(f, axis=2) - f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) - - f_1 = tf.expand_dims(f, axis=3) - f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) - - x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') - - x = x[:, :, ::downy, ::downx] - - # Back to channel last. - x = tf.transpose(x, perm=[0, 2, 3, 1]) - return x - - @tf.function - def filtered_lrelu(self, - x, fu=None, fd=None, b=None, - up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - #fu_w, fu_h = self.get_filter_size(fu) - #fd_w, fd_h = self.get_filter_size(fd) - - px0, px1, py0, py1 = self.parse_padding(padding) - - #batch_size, in_h, in_w, channels = x.shape - #in_dtype = x.dtype - #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down - #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down - - x = self.bias_act(x=x, b=b) - x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) - x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) - x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) - - return x - - - def call(self, inputs): - x = inputs - x = self.conv(x) - x = self.filtered_lrelu(x, - fu=self.u_filter, - fd=self.d_filter, - b=self.bias, - up=self.u_factor, - down=self.d_factor, - padding=self.padding, - gain=self.gain, - slope=self.slope, - clamp=self.conv_clamp) - return x - - def get_config(self): - base_config = super(SynthesisLayer, self).get_config() - return base_config - - -class SynthesisLayerNoModBN(Layer): - - def __init__(self, - critically_sampled, - - in_channels, - out_channels, - in_size, - out_size, - in_sampling_rate, - out_sampling_rate, - in_cutoff, - out_cutoff, - in_half_width, - out_half_width, - - conv_kernel = 3, - lrelu_upsampling = 2, - filter_size = 6, - conv_clamp = 256, - use_radial_filters = False, - is_torgb = False, - batch_size = 10, - **kwargs): - super(SynthesisLayerNoModBN, self).__init__(**kwargs) - self.critically_sampled = critically_sampled - self.bs = batch_size - - self.in_channels = in_channels - self.out_channels = out_channels - self.in_size = np.broadcast_to(np.asarray(in_size), [2]) - self.out_size = np.broadcast_to(np.asarray(out_size), [2]) - self.in_sampling_rate = in_sampling_rate - self.out_sampling_rate = out_sampling_rate - self.in_cutoff = in_cutoff - self.out_cutoff = out_cutoff - self.in_half_width = in_half_width - self.out_half_width = out_half_width - - self.is_torgb = is_torgb - - self.conv_kernel = 1 if is_torgb else conv_kernel - self.lrelu_upsampling = lrelu_upsampling - self.conv_clamp = conv_clamp - - self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1.0 if is_torgb else lrelu_upsampling) - - # Up sampling filter - 
self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) - assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate - self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 - self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, - cutoff=self.in_cutoff, - width=self.in_half_width*2, - fs=self.tmp_sampling_rate) - - # Down sampling filter - self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) - assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate - self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 - self.d_radial = use_radial_filters and not self.critically_sampled - self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, - cutoff=self.out_cutoff, - width=self.out_half_width*2, - fs=self.tmp_sampling_rate, - radial=self.d_radial) - # Compute padding - pad_total = (self.out_size - 1) * self.d_factor + 1 - pad_total -= (self.in_size) * self.u_factor - pad_total += self.u_taps + self.d_taps - 2 - pad_lo = (pad_total + self.u_factor) // 2 - pad_hi = pad_total - pad_lo - self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] - - self.gain = 1 if self.is_torgb else np.sqrt(2) - self.slope = 1 if self.is_torgb else 0.2 - - self.act_funcs = {'linear': - {'func': lambda x, **_: x, - 'def_alpha': 0, - 'def_gain': 1}, - 'lrelu': - {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), - 'def_alpha': 0.2, - 'def_gain': np.sqrt(2)}, - } - - b_init = tf.zeros_initializer() - self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), - dtype="float32"), - trainable=True) - self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') - self.bn = BatchNormalization() - - def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): - if numtaps == 1: - return None - - if not radial: - f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) - return f - - x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs - r = np.hypot(*np.meshgrid(x, x)) - f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) - beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) - w = np.kaiser(numtaps, beta) - f *= np.outer(w, w) - f /= np.sum(f) - return f - - def get_filter_size(self, f): - if f is None: - return 1, 1 - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - - def parse_padding(self, padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - - @tf.function - def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): - spec = self.act_funcs[act] - alpha = float(alpha if alpha is not None else spec['def_alpha']) - gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) - clamp = float(clamp if clamp is not None else -1) - - if b is not None: - x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) - x = spec['func'](x, alpha=alpha) - - x = x * gain - - if clamp >= 0: - x = tf.clip_by_value(x, -clamp, clamp) - return x - - def parse_scaling(self, scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return 
sx, sy - - @tf.function - def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - if f is None: - f = tf.ones([1, 1], dtype=tf.float32) - - batch_size, in_height, in_width, num_channels = x.shape - batch_size = tf.shape(x)[0] - - upx, upy = self.parse_scaling(up) - downx, downy = self.parse_scaling(down) - padx0, padx1, pady0, pady1 = self.parse_padding(padding) - - upW = in_width * upx + padx0 + padx1 - upH = in_height * upy + pady0 + pady1 - assert upW >= f.shape[-1] and upH >= f.shape[0] - - # Channel first format. - x = tf.transpose(x, perm=[0, 3, 1, 2]) - # Upsample by inserting zeros. - x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) - x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) - x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = tf.pad(x, [[0, 0], [0, 0], - [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], - [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) - x = x[:, :, - tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), - tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (tf.rank(f) / 2)) - f = tf.cast(f, dtype=x.dtype) - if not flip_filter: - f = tf.reverse(f, axis=[-1]) - f = tf.reshape(f, shape=(1, 1, f.shape[-1])) - f = tf.repeat(f, repeats=num_channels, axis=0) - - #if tf.rank(f) == 500: - # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) - # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - #else: - f_0 = tf.expand_dims(f, axis=2) - f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) - - f_1 = tf.expand_dims(f, axis=3) - f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) - - x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') - - x = x[:, :, ::downy, ::downx] - - # Back to channel last. 
- x = tf.transpose(x, perm=[0, 2, 3, 1]) - return x - - @tf.function - def filtered_lrelu(self, - x, fu=None, fd=None, b=None, - up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - #fu_w, fu_h = self.get_filter_size(fu) - #fd_w, fd_h = self.get_filter_size(fd) - - px0, px1, py0, py1 = self.parse_padding(padding) - - #batch_size, in_h, in_w, channels = x.shape - #in_dtype = x.dtype - #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down - #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down - - x = self.bias_act(x=x, b=b) - x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) - x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) - x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) - - return x - - - def call(self, inputs): - x = inputs - x = self.conv(x) - x = self.bn(x) - x = self.filtered_lrelu(x, - fu=self.u_filter, - fd=self.d_filter, - b=self.bias, - up=self.u_factor, - down=self.d_factor, - padding=self.padding, - gain=self.gain, - slope=self.slope, - clamp=self.conv_clamp) - return x - - def get_config(self): - base_config = super(SynthesisLayerNoModBN, self).get_config() - return base_config - - -class SynthesisInput(Layer): - def __init__(self, - w_dim, - channels, - size, - sampling_rate, - bandwidth, - **kwargs): - super(SynthesisInput, self).__init__(**kwargs) - self.w_dim = w_dim - self.channels = channels - self.size = np.broadcast_to(np.asarray(size), [2]) - self.sampling_rate = sampling_rate - self.bandwidth = bandwidth - - # Draw random frequencies from uniform 2D disc. - freqs = np.random.normal(size=(int(channels[0]), 2)) - radii = np.sqrt(np.sum(np.square(freqs), axis=1, keepdims=True)) - freqs /= radii * np.power(np.exp(np.square(radii)), 0.25) - freqs *= bandwidth - phases = np.random.uniform(size=[int(channels[0])]) - 0.5 - - # Setup parameters and buffers. - w_init = tf.random_normal_initializer() - self.weight = tf.Variable(initial_value=w_init(shape=(self.channels, self.channels), - dtype="float32"), rainable=True) - self.affine = Dense(4, kernel_initializer=tf.zeros_initializer, bias_initializer=tf.zeros_initializer) - self.transform = tf.eye(3, 3) - self.freqs = tf.constant(freqs) - self.phases = tf.constant(phases) - - def call(self, w): - # Batch dimension - transforms = tf.expand_dims(self.transform, axis=0) - freqs = tf.expand_dims(self.freqs, axis=0) - phases = tf.expand_dims(self.phases, axis=0) - - # Apply learned transformation. - t = self.affine(w) # t = (r_c, r_s, t_x, t_y) - t = t / tf.linalg.norm(t[:, :2], axis=1, keepdims=True) - # Inverse rotation wrt. resulting image. - m_r = tf.repeat(tf.expand_dims(tf.eye(3), axis=0), repeats=w.shape[0], axis=0) - m_r[:, 0, 0] = t[:, 0] # r'_c - m_r[:, 0, 1] = -t[:, 1] # r'_s - m_r[:, 1, 0] = t[:, 1] # r'_s - m_r[:, 1, 1] = t[:, 0] # r'_c - # Inverse translation wrt. resulting image. - m_t = tf.repeat(tf.expand_dims(tf.eye(3), axis=0), repeats=w.shape[0], axis=0) - m_t[:, 0, 2] = -t[:, 2] # t'_x - m_t[:, 1, 2] = -t[:, 3] # t'_y - transforms = m_r @ m_t @ transforms - - # Transform frequencies. - phases = phases + tf.expand_dims(freqs @ transforms[:, :2, 2:], axis=2) - freqs = freqs @ transforms[:, :2, :2] - - # Dampen out-of-band frequencies that may occur due to the user-specified transform. 
- amplitudes = tf.clip_by_value(1 - (tf.linalg.norm(freqs, axis=1, keepdims=True) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth), 0, 1) - - - - def get_config(self): - base_config = super(SynthesisInput, self).get_config() - return base_config - - -class SynthesisLayerFS(Layer): - - def __init__(self, - critically_sampled, - - in_channels, - out_channels, - in_size, - out_size, - in_sampling_rate, - out_sampling_rate, - in_cutoff, - out_cutoff, - in_half_width, - out_half_width, - - conv_kernel = 3, - lrelu_upsampling = 2, - filter_size = 6, - conv_clamp = 256, - use_radial_filters = False, - is_torgb = False, - **kwargs): - super(SynthesisLayerFS, self).__init__(**kwargs) - self.critically_sampled = critically_sampled - - self.in_channels = in_channels - self.out_channels = out_channels - self.in_size = np.broadcast_to(np.asarray(in_size), [2]) - self.out_size = np.broadcast_to(np.asarray(out_size), [2]) - self.in_sampling_rate = in_sampling_rate - self.out_sampling_rate = out_sampling_rate - self.in_cutoff = in_cutoff - self.out_cutoff = out_cutoff - self.in_half_width = in_half_width - self.out_half_width = out_half_width - - self.is_torgb = is_torgb - - self.conv_kernel = 1 if is_torgb else conv_kernel - self.lrelu_upsampling = lrelu_upsampling - self.conv_clamp = conv_clamp - - self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) - - # Up sampling filter - self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) - assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate - self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 - self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, - cutoff=self.in_cutoff, - width=self.in_half_width*2, - fs=self.tmp_sampling_rate) - - # Down sampling filter - self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) - assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate - self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 - self.d_radial = use_radial_filters and not self.critically_sampled - self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, - cutoff=self.out_cutoff, - width=self.out_half_width*2, - fs=self.tmp_sampling_rate, - radial=self.d_radial) - # Compute padding - pad_total = (self.out_size - 1) * self.d_factor + 1 - pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor - pad_total += self.u_taps + self.d_taps - 2 - pad_lo = (pad_total + self.u_factor) // 2 - pad_hi = pad_total - pad_lo - self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] - - self.gain = 1 if self.is_torgb else np.sqrt(2) - self.slope = 1 if self.is_torgb else 0.2 - - self.act_funcs = {'linear': - {'func': lambda x, **_: x, - 'def_alpha': 0, - 'def_gain': 1}, - 'lrelu': - {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), - 'def_alpha': 0.2, - 'def_gain': np.sqrt(2)}, - } - - b_init = tf.zeros_initializer() - self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), - dtype="float32"), - trainable=True) - self.affine = Dense(self.out_channels) - self.conv_mod = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') - self.bn = BatchNormalization() - self.conv_gamma = Conv2D(self.out_channels, kernel_size=1) - self.conv_beta = Conv2D(self.out_channels, kernel_size=1) - self.conv_gate = Conv2D(self.out_channels, kernel_size=1) - self.conv_final = 
Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') - - - def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): - if numtaps == 1: - return None - - if not radial: - f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) - return f - - x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs - r = np.hypot(*np.meshgrid(x, x)) - f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) - beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) - w = np.kaiser(numtaps, beta) - f *= np.outer(w, w) - f /= np.sum(f) - return f - - def get_filter_size(self, f): - if f is None: - return 1, 1 - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - - def parse_padding(self, padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - - @tf.function - def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): - spec = self.act_funcs[act] - alpha = float(alpha if alpha is not None else spec['def_alpha']) - gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) - clamp = float(clamp if clamp is not None else -1) - - if b is not None: - x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) - x = spec['func'](x, alpha=alpha) - - x = x * gain - - if clamp >= 0: - x = tf.clip_by_value(x, -clamp, clamp) - return x - - def parse_scaling(self, scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - - @tf.function - def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - if f is None: - f = tf.ones([1, 1], dtype=tf.float32) - - batch_size, in_height, in_width, num_channels = x.shape - batch_size = tf.shape(x)[0] - - upx, upy = self.parse_scaling(up) - downx, downy = self.parse_scaling(down) - padx0, padx1, pady0, pady1 = self.parse_padding(padding) - - upW = in_width * upx + padx0 + padx1 - upH = in_height * upy + pady0 + pady1 - assert upW >= f.shape[-1] and upH >= f.shape[0] - - # Channel first format. - x = tf.transpose(x, perm=[0, 3, 1, 2]) - # Upsample by inserting zeros. - x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) - x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) - x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = tf.pad(x, [[0, 0], [0, 0], - [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], - [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) - x = x[:, :, - tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0), - tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)] - - # Setup filter. 
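The filter setup that follows reshapes the 1-D lowpass produced by `design_lowpass_filter` so it can be applied twice, once along each spatial axis. That separable application is equivalent to a single 2-D convolution with the outer product of the taps; a small numpy check of the equivalence, with illustrative filter parameters rather than the layer's actual values:

```python
import numpy as np
import scipy.signal as sps

# Kaiser-windowed lowpass, as design_lowpass_filter builds for the non-radial
# case (numtaps/cutoff/width/fs here are illustrative only).
taps = sps.firwin(numtaps=12, cutoff=2.0, width=2.0, fs=16.0)

image = np.random.rand(32, 32)

# Separable: filter rows first, then columns, with the same 1-D taps.
rows = np.apply_along_axis(lambda r: np.convolve(r, taps, mode='valid'), 1, image)
separable = np.apply_along_axis(lambda c: np.convolve(c, taps, mode='valid'), 0, rows)

# A full 2-D convolution with the outer-product kernel gives the same result.
full = sps.convolve2d(image, np.outer(taps, taps), mode='valid')
assert np.allclose(separable, full)
```

This factorization is what lets `upfirdn2d` below run two cheap 1-D `tf.nn.conv2d` passes instead of one k×k convolution.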
- f = f * (gain ** (tf.rank(f) / 2)) - f = tf.cast(f, dtype=x.dtype) - if not flip_filter: - f = tf.reverse(f, axis=[-1]) - f = tf.reshape(f, shape=(1, 1, f.shape[-1])) - f = tf.repeat(f, repeats=num_channels, axis=0) - - # if tf.rank(f) == 500: - # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) - # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - # else: - f_0 = tf.expand_dims(f, axis=2) - f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) - - f_1 = tf.expand_dims(f, axis=3) - f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) - - x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') - - x = x[:, :, ::downy, ::downx] - - # Back to channel last. - x = tf.transpose(x, perm=[0, 2, 3, 1]) - return x - - @tf.function - def filtered_lrelu(self, - x, fu=None, fd=None, b=None, - up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - #fu_w, fu_h = self.get_filter_size(fu) - #fd_w, fd_h = self.get_filter_size(fd) - - px0, px1, py0, py1 = self.parse_padding(padding) - - #batch_size, in_h, in_w, channels = x.shape - #in_dtype = x.dtype - #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down - #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down - - x = self.bias_act(x=x, b=b) - x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) - x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) - x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) - - return x - - @tf.function - def aadm(self, x, w, a): - w_affine = self.affine(w) - x_norm = self.bn(x) - - x_id = self.conv_mod([x_norm, w_affine]) - - gate = self.conv_gate(x_norm) - gate = tf.nn.sigmoid(gate) - - x_att_beta = self.conv_beta(a) - x_att_gamma = self.conv_gamma(a) - - x_att = x_norm * x_att_beta + x_att_gamma - - h = x_id * gate + (1 - gate) * x_att - - return h - - - def call(self, inputs): - x, w, a = inputs - x = self.conv_final(x) - x = self.aadm(x, w, a) - x = self.filtered_lrelu(x, - fu=self.u_filter, - fd=self.d_filter, - b=self.bias, - up=self.u_factor, - down=self.d_factor, - padding=self.padding, - gain=self.gain, - slope=self.slope, - clamp=self.conv_clamp) - return x - - def get_config(self): - base_config = super(SynthesisLayerFS, self).get_config() - return base_config - - -class SynthesisLayerUpDownOnly(Layer): - - def __init__(self, - critically_sampled, - - in_channels, - out_channels, - in_size, - out_size, - in_sampling_rate, - out_sampling_rate, - in_cutoff, - out_cutoff, - in_half_width, - out_half_width, - - conv_kernel = 3, - lrelu_upsampling = 2, - filter_size = 6, - conv_clamp = 256, - use_radial_filters = False, - is_torgb = False, - **kwargs): - super(SynthesisLayerUpDownOnly, self).__init__(**kwargs) - self.critically_sampled = critically_sampled - - self.in_channels = in_channels - self.out_channels = out_channels - self.in_size = np.broadcast_to(np.asarray(in_size), [2]) - self.out_size = np.broadcast_to(np.asarray(out_size), [2]) - self.in_sampling_rate = in_sampling_rate - self.out_sampling_rate = out_sampling_rate - self.in_cutoff = in_cutoff - self.out_cutoff = out_cutoff - self.in_half_width = in_half_width - self.out_half_width = out_half_width - - self.is_torgb = is_torgb - - self.conv_kernel = 1 if is_torgb else conv_kernel - self.lrelu_upsampling = lrelu_upsampling - self.conv_clamp = conv_clamp - - self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else 
lrelu_upsampling) - - # Up sampling filter - self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) - assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate - self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 - self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, - cutoff=self.in_cutoff, - width=self.in_half_width*2, - fs=self.tmp_sampling_rate) - - # Down sampling filter - self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) - assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate - self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 - self.d_radial = use_radial_filters and not self.critically_sampled - self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, - cutoff=self.out_cutoff, - width=self.out_half_width*2, - fs=self.tmp_sampling_rate, - radial=self.d_radial) - # Compute padding - pad_total = (self.out_size - 1) * self.d_factor + 1 - pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor - pad_total += self.u_taps + self.d_taps - 2 - pad_lo = (pad_total + self.u_factor) // 2 - pad_hi = pad_total - pad_lo - self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] - - self.gain = 1 if self.is_torgb else np.sqrt(2) - self.slope = 1 if self.is_torgb else 0.2 - - self.act_funcs = {'linear': - {'func': lambda x, **_: x, - 'def_alpha': 0, - 'def_gain': 1}, - 'lrelu': - {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), - 'def_alpha': 0.2, - 'def_gain': np.sqrt(2)}, - } - - def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): - if numtaps == 1: - return None - - if not radial: - f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) - return f - - x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs - r = np.hypot(*np.meshgrid(x, x)) - f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) - beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) - w = np.kaiser(numtaps, beta) - f *= np.outer(w, w) - f /= np.sum(f) - return f - - def get_filter_size(self, f): - if f is None: - return 1, 1 - assert 1 <= f.ndim <= 2 - return f.shape[-1], f.shape[0] # width, height - - def parse_padding(self, padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, (int, np.integer)) for x in padding) - padding = [int(x) for x in padding] - if len(padding) == 2: - px, py = padding - padding = [px, px, py, py] - px0, px1, py0, py1 = padding - return px0, px1, py0, py1 - - def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): - spec = self.act_funcs[act] - alpha = float(alpha if alpha is not None else spec['def_alpha']) - gain = float(gain if gain is not None else spec['def_gain']) - clamp = float(clamp if clamp is not None else -1) - - if b is not None: - x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) - x = spec['func'](x, alpha=alpha) - - if gain != 1: - x = x * gain - - if clamp >= 0: - x = tf.clip_by_value(x, -clamp, clamp) - return x - - def parse_scaling(self, scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - - def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - if f is None: - f = tf.ones([1, 1], dtype=tf.float32) - - batch_size, in_height, in_width, num_channels = x.shape - - 
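The next few statements implement zero-insertion upsampling through a reshape/pad/reshape. A minimal sketch of that step in isolation (the helper name is hypothetical); note that the batch dimension stays first, matching the NCHW tensor produced by the transpose, whereas the first reshape below lists `num_channels` before `batch_size`:

```python
import tensorflow as tf

def upsample_zero_insert(x_nchw, upx, upy):
    """Insert (up - 1) zeros after every sample along each spatial axis of an
    NCHW tensor; statically known channel and spatial sizes are assumed."""
    _, channels, in_h, in_w = x_nchw.shape
    x = tf.reshape(x_nchw, [-1, channels, in_h, 1, in_w, 1])
    # Pad the axis adjacent to H by (upy - 1) and the one adjacent to W by (upx - 1).
    x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1]])
    return tf.reshape(x, [-1, channels, in_h * upy, in_w * upx])
```

With `upx = upy = 2`, a 4×4 map becomes 8×8 with three of every four samples zero; the subsequent lowpass convolution then interpolates those gaps.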
upx, upy = self.parse_scaling(up) - downx, downy = self.parse_scaling(down) - padx0, padx1, pady0, pady1 = self.parse_padding(padding) - - upW = in_width * upx + padx0 + padx1 - upH = in_height * upy + pady0 + pady1 - assert upW >= f.shape[-1] and upH >= f.shape[0] - - # Channel first format. - x = tf.transpose(x, perm=[0, 3, 1, 2]) - - # Upsample by inserting zeros. - x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) - x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) - x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = tf.pad(x, [[0, 0], [0, 0], - [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], - [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) - x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = tf.cast(f, dtype=x.dtype) - if not flip_filter: - f = tf.reverse(f, axis=[-1]) - f = tf.reshape(f, shape=(1, 1, f.shape[-1])) - f = tf.repeat(f, repeats=num_channels, axis=0) - - if tf.rank(f) == 4: - f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) - x = tf.nn.conv2d(x, f_0, 1, 'VALID') - else: - f_0 = tf.expand_dims(f, axis=2) - f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) - - f_1 = tf.expand_dims(f, axis=3) - f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) - - x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') - x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') - - x = x[:, :, ::downy, ::downx] - - # Back to channel last. - x = tf.transpose(x, perm=[0, 2, 3, 1]) - return x - - - def filtered_lrelu(self, - x, fu=None, fd=None, b=None, - up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): - - px0, px1, py0, py1 = self.parse_padding(padding) - - x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) - x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) - x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) - - return x - - - def call(self, inputs): - x = inputs - x = self.filtered_lrelu(x, - fu=self.u_filter, - fd=self.d_filter, - b=self.bias, - up=self.u_factor, - down=self.d_factor, - padding=self.padding, - gain=self.gain, - slope=self.slope, - clamp=self.conv_clamp) - return x - - def get_config(self): - base_config = super(SynthesisLayerUpDownOnly, self).get_config() - return base_config - - -class Localization(Layer): - def __init__(self): - super(Localization, self).__init__() - - self.pool = MaxPooling2D() - self.conv_0 = Conv2D(36, 5, activation='relu') - self.conv_1 = Conv2D(36, 5, activation='relu') - self.flatten = Flatten() - self.fc_0 = Dense(36, activation='relu') - self.fc_1 = Dense(6, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), - kernel_initializer='zeros') - self.reshape = Reshape((2, 3)) - - def build(self, input_shape): - print(input_shape) - - def compute_output_shape(self, input_shape): - return [None, 6] - - def call(self, inputs): - x = self.conv_0(inputs) - x = self.pool(x) - x = self.conv_1(x) - x = self.pool(x) - x = self.flatten(x) - x = self.fc_0(x) - theta = self.fc_1(x) - theta = self.reshape(theta) - - return theta - - -class BilinearInterpolation(Layer): - def __init__(self, height=36, width=36): - super(BilinearInterpolation, self).__init__() - self.height = height - self.width = width - - def compute_output_shape(self, input_shape): - return 
[None, self.height, self.width, 1] - - def get_config(self): - config = { - 'height': self.height, - 'width': self.width - } - base_config = super(BilinearInterpolation, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - def advance_indexing(self, inputs, x, y): - shape = tf.shape(inputs) - batch_size = shape[0] - - batch_idx = tf.range(0, batch_size) - batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) - - b = tf.tile(batch_idx, (1, self.height, self.width)) - indices = tf.stack([b, y, x], 3) - - return tf.gather_nd(inputs, indices) - - def grid_generator(self, batch): - x = tf.linspace(-1, 1, self.width) - y = tf.linspace(-1, 1, self.height) - - xx, yy = tf.meshgrid(x, y) - xx = tf.reshape(xx, (-1,)) - yy = tf.reshape(yy, (-1,)) - - homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)]) - homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0) - homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1]) - homogenous_coordinates = tf.cast(homogenous_coordinates, dtype=tf.float32) - return homogenous_coordinates - - def interpolate(self, images, homogenous_coordinates, theta): - - with tf.name_scope("Transformation"): - transformed = tf.matmul(theta, homogenous_coordinates) - transformed = tf.transpose(transformed, perm=[0, 2, 1]) - transformed = tf.reshape(transformed, [-1, self.height, self.width, 2]) - - x_transformed = transformed[:, :, :, 0] - y_transformed = transformed[:, :, :, 1] - - x = ((x_transformed + 1.) * tf.cast(self.width, dtype=tf.float32)) * 0.5 - y = ((y_transformed + 1.) * tf.cast(self.height, dtype=tf.float32)) * 0.5 - - with tf.name_scope("VaribleCasting"): - x0 = tf.cast(tf.math.floor(x), dtype=tf.int32) - x1 = x0 + 1 - y0 = tf.cast(tf.math.floor(y), dtype=tf.int32) - y1 = y0 + 1 - - x0 = tf.clip_by_value(x0, 0, self.width-1) - x1 = tf.clip_by_value(x1, 0, self.width - 1) - y0 = tf.clip_by_value(y0, 0, self.height - 1) - y1 = tf.clip_by_value(y1, 0, self.height - 1) - x = tf.clip_by_value(x, 0, tf.cast(self.width, dtype=tf.float32) - 1.0) - y = tf.clip_by_value(y, 0, tf.cast(self.height, dtype=tf.float32) - 1.0) - - with tf.name_scope("AdvancedIndexing"): - i_a = self.advance_indexing(images, x0, y0) - i_b = self.advance_indexing(images, x0, y1) - i_c = self.advance_indexing(images, x1, y0) - i_d = self.advance_indexing(images, x1, y1) - - with tf.name_scope("Interpolation"): - x0 = tf.cast(x0, dtype=tf.float32) - x1 = tf.cast(x1, dtype=tf.float32) - y0 = tf.cast(y0, dtype=tf.float32) - y1 = tf.cast(y1, dtype=tf.float32) - - w_a = (x1 - x) * (y1 - y) - w_b = (x1 - x) * (y - y0) - w_c = (x - x0) * (y1 - y) - w_d = (x - x0) * (y - y0) - - w_a = tf.expand_dims(w_a, axis=3) - w_b = tf.expand_dims(w_b, axis=3) - w_c = tf.expand_dims(w_c, axis=3) - w_d = tf.expand_dims(w_d, axis=3) - - return tf.math.add_n([w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d]) - - def call(self, inputs): - images, theta = inputs - homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0]) - return self.interpolate(images, homogenous_coordinates, theta) - - -class ResBlockLR(Layer): - def __init__(self, filters=16): - super(ResBlockLR, self).__init__() - self.filters = filters - - self.conv_0 = Conv2D(filters=filters, - kernel_size=3, - strides=1, - padding='same') - self.bn_0 = BatchNormalization() - self.conv_1 = Conv2D(filters=filters, - kernel_size=3, - strides=1, - padding='same') - self.bn_1 = BatchNormalization() - - def get_config(self): - config = { - 'filters': self.filters, - } - base_config = 
super(ResBlockLR, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - def call(self, inputs): - x = self.conv_0(inputs) - x = self.bn_0(x) - x = tf.nn.leaky_relu(x, alpha=0.2) - x = self.conv_1(x) - x = self.bn_1(x) - return x + inputs - - -class LearnedResize(Layer): - def __init__(self, width, height, filters=16, in_channels=3, num_res_block=3, interpolation='bilinear'): - super(LearnedResize, self).__init__() - self.filters = filters - self.num_res_block = num_res_block - self.interpolation = interpolation - self.in_channels = in_channels - self.width = width - self.height = height - - self.resize_layer = tf.keras.layers.experimental.preprocessing.Resizing(height, - width, - interpolation=interpolation) - - self.init_layers = tf.keras.models.Sequential([Conv2D(filters=filters, - kernel_size=7, - strides=1, - padding='same'), - LeakyReLU(0.2), - Conv2D(filters=filters, - kernel_size=1, - strides=1, - padding='same'), - LeakyReLU(0.2), - BatchNormalization() - ]) - res_blocks = [] - for i in range(num_res_block): - res_blocks.append(ResBlockLR(filters=filters)) - res_blocks.append(Conv2D(filters=filters, - kernel_size=3, - strides=1, - padding='same', - use_bias=False)) - res_blocks.append(BatchNormalization()) - self.res_block_pipe = tf.keras.models.Sequential(res_blocks) - self.final_conv = Conv2D(filters=in_channels, - kernel_size=3, - strides=1, - padding='same') - - - def compute_output_shape(self, input_shape): - return [None, self.target_size[0], self.target_size[1], input_shape[-1]] - - def get_config(self): - config = { - 'filters': self.filters, - 'num_res_block': self.num_res_block, - 'interpolation': self.interpolation, - 'width': self.width, - 'height': self.height, - } - base_config = super(LearnedResize, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - def call(self, inputs): - x_l = self.init_layers(inputs) - x_l_0 = self.resize_layer(x_l) - x_l = self.res_block_pipe(x_l_0) - x_l = x_l + x_l_0 - x_l = self.final_conv(x_l) - - x = self.resize_layer(inputs) - - return x + x_l
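A usage sketch for the learned resizer above: `LearnedResize` stacks a small residual CNN on top of a fixed bilinear resize, so its output is a plain resize plus a learned correction. The shapes below are illustrative assumptions:

```python
import tensorflow as tf

# Downscale a batch of 224x224 RGB images to 64x64 with a learned correction.
resizer = LearnedResize(width=64, height=64, filters=16, in_channels=3, num_res_block=3)

images = tf.random.normal([2, 224, 224, 3])
resized = resizer(images)   # [2, 64, 64, 3] = bilinear resize + residual branch
```

Note that `compute_output_shape` refers to `self.target_size`, which is never set in `__init__`; if that method is ever exercised it would need to use `self.height` and `self.width` instead.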