# face-swap / retinaface / models.py
# yotamsapi's picture
# Create models.py
# 3fc64a0
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications import MobileNetV2, ResNet50
from tensorflow.keras.layers import Input, Conv2D, ReLU, LeakyReLU
from retinaface.anchor import decode_tf, prior_box_tf
def _regularizer(weights_decay):
    """Build the L2 kernel regularizer used by the convolution units.

    Args:
        weights_decay: L2 penalty factor handed to the Keras regularizer.

    Returns:
        The object produced by ``tf.keras.regularizers.l2``.
    """
    l2_penalty = tf.keras.regularizers.l2(weights_decay)
    return l2_penalty
def _kernel_init(scale=1.0, seed=None):
    """He-normal kernel initializer with an optional variance scale and seed.

    Both arguments used to be accepted and silently ignored; they are now
    honoured. With the defaults the result is still a plain He-normal
    initializer.

    Args:
        scale: Multiplier applied to the He variance factor. ``1.0``
            (default) yields the standard He-normal initializer.
        seed: Optional integer seed forwarded to the initializer; ``None``
            (default) leaves it unseeded.

    Returns:
        A Keras initializer instance.
    """
    if scale == 1.0:
        # Default path: same initializer type as before, now seedable.
        return tf.keras.initializers.he_normal(seed=seed)
    # He normal is variance scaling with factor 2 over fan-in, sampled from a
    # truncated normal; a custom scale simply multiplies that factor.
    return tf.keras.initializers.VarianceScaling(
        scale=2.0 * scale, mode='fan_in',
        distribution='truncated_normal', seed=seed)
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """Batch normalization that is truly frozen when ``trainable`` is False.

    The ``training`` flag passed to ``call`` is AND-ed with
    ``self.trainable`` before it reaches the parent layer, so a
    non-trainable layer is always invoked with ``training=False``.

    ref: https://github.com/zzh8829/yolov3-tf2
    """

    def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, center=True,
                 scale=True, name=None, **kwargs):
        """Forward all arguments to the Keras layer (custom defaults for
        ``momentum`` and ``epsilon``)."""
        super().__init__(
            axis=axis,
            momentum=momentum,
            epsilon=epsilon,
            center=center,
            scale=scale,
            name=name,
            **kwargs)

    def call(self, x, training=False):
        """Apply batch norm; training mode requires ``self.trainable`` too."""
        # A missing flag is treated as inference.
        if training is None:
            training = tf.constant(False)
        effective_training = tf.logical_and(training, self.trainable)
        return super().call(x, effective_training)
def Backbone(backbone_type='ResNet50', use_pretrain=True):
    """Create a callable that applies a feature-extraction backbone.

    Args:
        backbone_type: ``'ResNet50'`` or ``'MobileNetV2'``.
        use_pretrain: When truthy the Keras application is created with
            ``weights='imagenet'``; otherwise with ``weights=None``.

    Returns:
        A function ``backbone(x)``. When called it builds the chosen Keras
        application for ``x.shape[1:]``, wraps three of its intermediate
        layer outputs in a ``Model`` and applies that model to the
        preprocessed ``x``. An unrecognized ``backbone_type`` raises
        ``NotImplementedError`` at that call, not when ``Backbone`` itself
        is called.
    """
    weights = 'imagenet' if use_pretrain else None

    def backbone(x):
        # Reject unknown types before touching the input.
        if backbone_type not in ('ResNet50', 'MobileNetV2'):
            raise NotImplementedError(
                'Backbone type {} is not recognized.'.format(backbone_type))

        if backbone_type == 'ResNet50':
            extractor = ResNet50(
                input_shape=x.shape[1:], include_top=False, weights=weights)
            # Author's shape notes: [80, 80, 512], [40, 40, 1024],
            # [20, 20, 2048].
            pick_indices = (80, 142, 174)
            preprocess = tf.keras.applications.resnet.preprocess_input
        else:
            extractor = MobileNetV2(
                input_shape=x.shape[1:], include_top=False, weights=weights)
            # Author's shape notes: [80, 80, 32], [40, 40, 96],
            # [20, 20, 160].
            pick_indices = (54, 116, 143)
            preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

        picked_outputs = tuple(
            extractor.layers[idx].output for idx in pick_indices)
        feature_model = Model(
            extractor.input, picked_outputs, name=backbone_type + '_extrator')
        return feature_model(preprocess(x))

    return backbone
class ConvUnit(tf.keras.layers.Layer):
    """Convolution followed by batch normalization and an activation.

    Args:
        f: Number of convolution filters.
        k: Convolution kernel size.
        s: Convolution stride.
        wd: Weight-decay factor for the kernel's L2 regularizer.
        act: ``None`` (identity), ``'relu'`` or ``'lrelu'`` (LeakyReLU with
            0.1). Any other value raises ``NotImplementedError``.
    """

    def __init__(self, f, k, s, wd, act=None, **kwargs):
        super().__init__(**kwargs)
        # Bias is disabled on the convolution; batch norm follows directly.
        self.conv = Conv2D(
            filters=f,
            kernel_size=k,
            strides=s,
            padding='same',
            kernel_initializer=_kernel_init(),
            kernel_regularizer=_regularizer(wd),
            use_bias=False)
        self.bn = BatchNormalization()

        if act == 'relu':
            self.act_fn = ReLU()
        elif act == 'lrelu':
            self.act_fn = LeakyReLU(0.1)
        elif act is None:
            self.act_fn = tf.identity
        else:
            raise NotImplementedError(
                'Activation function type {} is not recognized.'.format(act))

    def call(self, x):
        """Run conv -> batch norm -> activation on ``x``."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.act_fn(normalized)
class FPN(tf.keras.layers.Layer):
    """Feature pyramid over three backbone feature maps.

    Each input level goes through a 1x1 ``ConvUnit``; coarser levels are
    upsampled (nearest neighbour) to the next finer level's spatial size,
    added, and smoothed with a 3x3 ``ConvUnit``.

    Args:
        out_ch: Channel count of every output level. Values ``<= 64`` use
            ``'lrelu'`` activations, larger values use ``'relu'``.
        wd: Weight-decay factor forwarded to every ``ConvUnit``.
    """

    def __init__(self, out_ch, wd, **kwargs):
        super().__init__(**kwargs)
        self.out_ch = out_ch
        self.wd = wd
        act = 'lrelu' if out_ch <= 64 else 'relu'

        def unit(kernel):
            return ConvUnit(f=out_ch, k=kernel, s=1, wd=wd, act=act)

        # 1x1 lateral projections, one per input level.
        self.output1 = unit(1)
        self.output2 = unit(1)
        self.output3 = unit(1)
        # 3x3 smoothing applied after each top-down addition.
        self.merge1 = unit(3)
        self.merge2 = unit(3)

    def call(self, x):
        """Return ``(level1, level2, level3)`` for inputs ``x[0..2]``."""

        def upsample_to(source, reference):
            # Spatial size is read dynamically from the reference tensor.
            target_h = tf.shape(reference)[1]
            target_w = tf.shape(reference)[2]
            return tf.image.resize(
                source, [target_h, target_w], method='nearest')

        level1 = self.output1(x[0])  # author's note: [80, 80, out_ch]
        level2 = self.output2(x[1])  # author's note: [40, 40, out_ch]
        level3 = self.output3(x[2])  # author's note: [20, 20, out_ch]

        level2 = self.merge2(level2 + upsample_to(level3, level2))
        level1 = self.merge1(level1 + upsample_to(level2, level1))
        return level1, level2, level3

    def get_config(self):
        """Return the base layer config extended with ``out_ch`` and ``wd``."""
        return {
            **super().get_config(),
            'out_ch': self.out_ch,
            'wd': self.wd,
        }
class SSH(tf.keras.layers.Layer):
    """Single Stage Headless context module.

    Three branches built from stacked 3x3 ``ConvUnit`` layers (one, two and
    three convolutions deep) are concatenated on the channel axis and passed
    through a ReLU. The branches contribute ``out_ch // 2``, ``out_ch // 4``
    and ``out_ch // 4`` channels.

    Args:
        out_ch: Total output channels; must be divisible by 4.
        wd: Weight-decay factor forwarded to every ``ConvUnit``.
    """

    def __init__(self, out_ch, wd, **kwargs):
        super().__init__(**kwargs)
        assert out_ch % 4 == 0
        self.out_ch = out_ch
        self.wd = wd
        act = 'lrelu' if out_ch <= 64 else 'relu'
        half, quarter = out_ch // 2, out_ch // 4

        # The last unit of each branch has no activation; a shared ReLU is
        # applied after concatenation instead.
        self.conv_3x3 = ConvUnit(f=half, k=3, s=1, wd=wd, act=None)

        self.conv_5x5_1 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=act)
        self.conv_5x5_2 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=None)

        self.conv_7x7_2 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=act)
        self.conv_7x7_3 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=None)

        self.relu = ReLU()

    def call(self, x):
        """Apply the three branches to ``x`` and fuse them."""
        branch_3x3 = self.conv_3x3(x)

        # The first 3x3 unit is shared by the two deeper branches.
        shared = self.conv_5x5_1(x)
        branch_5x5 = self.conv_5x5_2(shared)
        branch_7x7 = self.conv_7x7_3(self.conv_7x7_2(shared))

        fused = tf.concat([branch_3x3, branch_5x5, branch_7x7], axis=3)
        return self.relu(fused)

    def get_config(self):
        """Return the base layer config extended with ``out_ch`` and ``wd``."""
        return {
            **super().get_config(),
            'out_ch': self.out_ch,
            'wd': self.wd,
        }
class BboxHead(tf.keras.layers.Layer):
    """Box regression head: 1x1 convolution reshaped to 4 values per anchor.

    Args:
        num_anchor: Anchors per spatial location.
        wd: Stored and reported by ``get_config``; the convolution in this
            layer is created without a regularizer.
    """

    def __init__(self, num_anchor, wd, **kwargs):
        super().__init__(**kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 4, kernel_size=1, strides=1)

    def call(self, x):
        """Return a tensor of shape ``[-1, h * w * num_anchor, 4]``."""
        # Height and width are read dynamically from the input tensor.
        dims = tf.shape(x)
        h, w = dims[1], dims[2]
        regressed = self.conv(x)
        return tf.reshape(regressed, [-1, h * w * self.num_anchor, 4])

    def get_config(self):
        """Return the base layer config extended with this head's settings."""
        return {
            **super().get_config(),
            'num_anchor': self.num_anchor,
            'wd': self.wd,
        }
class LandmarkHead(tf.keras.layers.Layer):
    """Landmark head: 1x1 convolution reshaped to 10 values per anchor.

    Args:
        num_anchor: Anchors per spatial location.
        wd: Stored and reported by ``get_config``; the convolution in this
            layer is created without a regularizer.
        name: Layer name, ``'LandmarkHead'`` by default.
    """

    def __init__(self, num_anchor, wd, name='LandmarkHead', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 10, kernel_size=1, strides=1)

    def call(self, x):
        """Return a tensor of shape ``[-1, h * w * num_anchor, 10]``."""
        # Height and width are read dynamically from the input tensor.
        dims = tf.shape(x)
        h, w = dims[1], dims[2]
        regressed = self.conv(x)
        return tf.reshape(regressed, [-1, h * w * self.num_anchor, 10])

    def get_config(self):
        """Return the base layer config extended with this head's settings."""
        return {
            **super().get_config(),
            'num_anchor': self.num_anchor,
            'wd': self.wd,
        }
class ClassHead(tf.keras.layers.Layer):
    """Classification head: 1x1 convolution reshaped to 2 values per anchor.

    Args:
        num_anchor: Anchors per spatial location.
        wd: Stored and reported by ``get_config``; the convolution in this
            layer is created without a regularizer.
        name: Layer name, ``'ClassHead'`` by default.
    """

    def __init__(self, num_anchor, wd, name='ClassHead', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 2, kernel_size=1, strides=1)

    def call(self, x):
        """Return a tensor of shape ``[-1, h * w * num_anchor, 2]``."""
        # Height and width are read dynamically from the input tensor.
        dims = tf.shape(x)
        h, w = dims[1], dims[2]
        logits = self.conv(x)
        return tf.reshape(logits, [-1, h * w * self.num_anchor, 2])

    def get_config(self):
        """Return the base layer config extended with this head's settings."""
        return {
            **super().get_config(),
            'num_anchor': self.num_anchor,
            'wd': self.wd,
        }
def RetinaFaceModel(cfg, training=False, iou_th=0.4, score_th=0.02,
                    name='RetinaFaceModel'):
    """Assemble the RetinaFace network and return two Keras models.

    Args:
        cfg: Mapping read for ``'weights_decay'``, ``'out_channel'``,
            ``'min_sizes'`` and ``'backbone_type'``; additionally
            ``'input_size'`` when ``training`` is truthy, and ``'steps'``,
            ``'clip'`` and ``'variances'`` when it is not.
        training: If truthy, the input has a fixed square size and the first
            model outputs the raw head tensors. Otherwise the input size is
            left unspecified and the first model decodes the predictions and
            applies non-max suppression.
        iou_th: IoU threshold for non-max suppression (inference only).
        score_th: Score threshold for non-max suppression (inference only).
        name: Name of the first model; the second is ``name + '_bb_only'``.

    Returns:
        A tuple ``(model, heads_model)``. ``heads_model`` always outputs
        ``[bbox_regressions, landm_regressions, classifications]``.
    """
    input_size = cfg['input_size'] if training else None
    wd = cfg['weights_decay']
    out_ch = cfg['out_channel']
    # Anchors per location are taken from the first pyramid level's sizes.
    num_anchor = len(cfg['min_sizes'][0])
    backbone_type = cfg['backbone_type']

    # define model
    inputs = Input([input_size, input_size, 3], name='input_image')
    backbone_maps = Backbone(backbone_type=backbone_type)(inputs)
    pyramid = FPN(out_ch=out_ch, wd=wd)(backbone_maps)
    features = [SSH(out_ch=out_ch, wd=wd)(level) for level in pyramid]

    # Each pyramid level gets its own head; results are joined on the
    # anchor axis.
    bbox_regressions = tf.concat(
        [BboxHead(num_anchor, wd=wd)(feat) for feat in features], axis=1)
    landm_regressions = tf.concat(
        [LandmarkHead(num_anchor, wd=wd, name=f'LandmarkHead_{idx}')(feat)
         for idx, feat in enumerate(features)], axis=1)
    classifications = tf.concat(
        [ClassHead(num_anchor, wd=wd, name=f'ClassHead_{idx}')(feat)
         for idx, feat in enumerate(features)], axis=1)
    classifications = tf.keras.layers.Softmax(axis=-1)(classifications)

    if training:
        out = (bbox_regressions, landm_regressions, classifications)
    else:
        # only for batch size 1
        boxes = bbox_regressions[0]
        landmarks = landm_regressions[0]
        landmarks_valid = tf.ones_like(
            classifications[0, :, 0][..., tf.newaxis])
        confidence = classifications[0, :, 1][..., tf.newaxis]
        # [bboxes, landms, landms_valid, conf]
        preds = tf.concat(
            [boxes, landmarks, landmarks_valid, confidence], 1)

        # Priors are generated from the runtime height/width of the input.
        image_hw = (tf.shape(inputs)[1], tf.shape(inputs)[2])
        priors = prior_box_tf(
            image_hw, cfg['min_sizes'], cfg['steps'], cfg['clip'])
        decode_preds = decode_tf(preds, priors, cfg['variances'])

        # First four columns are the boxes, the last column is the score.
        selected_indices = tf.image.non_max_suppression(
            boxes=decode_preds[:, :4],
            scores=decode_preds[:, -1],
            max_output_size=tf.shape(decode_preds)[0],
            iou_threshold=iou_th,
            score_threshold=score_th)
        out = tf.gather(decode_preds, selected_indices)

    full_model = Model(inputs, out, name=name)
    heads_model = Model(
        inputs,
        [bbox_regressions, landm_regressions, classifications],
        name=name + '_bb_only')
    return full_model, heads_model