# enhance-me/enhance_me/zero_dce/losses/spatial_constancy.py
# Author: geekyrakshit — commit 659a217 ("added loss functions")
import tensorflow as tf
from tensorflow.keras import losses
class SpatialConsistencyLoss(losses.Loss):
    """Spatial consistency loss from Zero-DCE (Guo et al., CVPR 2020).

    Both images are reduced to single-channel intensity maps (mean over the
    last axis), then average-pooled over non-overlapping ``pool_size`` x
    ``pool_size`` regions. For every region the intensity difference to its
    left/right/up/down neighbour is taken, and the loss is the squared mismatch
    between those neighbour differences in ``y_true`` and ``y_pred``, summed
    over the four directions.

    Assumes NHWC image batches ``(batch, height, width, channels)``, which is
    TensorFlow's default layout for ``avg_pool2d``/``conv2d`` — TODO confirm
    callers never pass NCHW.

    Args:
        pool_size: Side length of the square regions averaged before comparing
            neighbours. Defaults to 4, the value the original code hard-coded.
        **kwargs: Forwarded to ``keras.losses.Loss`` (e.g. ``name``).
            ``reduction`` defaults to ``"none"`` so the per-region loss map
            produced by ``call`` is returned unreduced.
    """

    def __init__(self, pool_size=4, **kwargs):
        kwargs.setdefault("reduction", "none")
        super().__init__(**kwargs)
        self.pool_size = pool_size

        # tf.nn.conv2d filters are (filter_height, filter_width, in_channels,
        # out_channels); each 3x3 stencil must therefore become (3, 3, 1, 1).
        # Do not write these as 4-D nested literals such as
        # [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]] — that has shape
        # (1, 3, 1, 3), i.e. three 1x3 filters, not one 3x3 stencil.
        self.left_kernel = self._make_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = self._make_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = self._make_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = self._make_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])
        # All four stencils stacked on the output-channel axis: (3, 3, 1, 4),
        # so a single convolution produces all four directional differences.
        self._kernels = tf.concat(
            [self.left_kernel, self.right_kernel, self.up_kernel, self.down_kernel],
            axis=-1,
        )

    @staticmethod
    def _make_kernel(stencil):
        """Lift a 3x3 stencil (nested lists) to a (3, 3, 1, 1) float32 filter."""
        return tf.constant(stencil, dtype=tf.float32)[:, :, tf.newaxis, tf.newaxis]

    def _region_intensity(self, images):
        """Average over channels, then over pool_size x pool_size regions."""
        # Cast first so integer or non-float32 images neither break the
        # float32 convolution nor produce an integer channel mean.
        images = tf.cast(images, tf.float32)
        intensity = tf.reduce_mean(images, axis=-1, keepdims=True)
        return tf.nn.avg_pool2d(
            intensity, ksize=self.pool_size, strides=self.pool_size, padding="VALID"
        )

    def call(self, y_true, y_pred):
        """Compute the per-region spatial consistency loss map.

        Args:
            y_true: Reference (original) images.
            y_pred: Enhanced images, same shape as ``y_true``.

        Returns:
            Tensor of shape ``(batch, height // pool_size, width // pool_size, 1)``
            with, per region, the sum over the four directions of the squared
            difference between the neighbour differences of the two images.
        """
        original_pool = self._region_intensity(y_true)
        enhanced_pool = self._region_intensity(y_pred)
        # Convolution (with zero "SAME" padding) is linear, so
        # conv(orig) - conv(enh) == conv(orig - enh): one conv gives all four
        # directional mismatches (left, right, up, down) as 4 channels.
        mismatch = tf.nn.conv2d(
            original_pool - enhanced_pool, self._kernels, strides=1, padding="SAME"
        )
        return tf.reduce_sum(tf.square(mismatch), axis=-1, keepdims=True)

    def get_config(self):
        """Return the serializable config, including ``pool_size``."""
        config = super().get_config()
        config["pool_size"] = self.pool_size
        return config