import tensorflow as tf
from tensorflow.keras import losses


def _direction_kernel(rows):
    """Return a 3x3 single-channel ``tf.nn.conv2d`` filter built from ``rows``.

    ``tf.nn.conv2d`` expects filters laid out as
    ``[filter_height, filter_width, in_channels, out_channels]``. The 3x3
    pattern is therefore reshaped to ``(3, 3, 1, 1)``: one input channel (the
    channel-averaged image) and one output channel. Writing the nested list
    directly as a rank-4 literal of shape ``(1, 3, 1, 3)`` would instead turn
    each row into three output channels of a 1x3 filter.
    """
    return tf.reshape(tf.constant(rows, dtype=tf.float32), (3, 3, 1, 1))


class SpatialConsistencyLoss(losses.Loss):
    """Spatial consistency loss (as used in Zero-DCE).

    Both inputs are averaged over the channel axis and downsampled with a
    4x4, stride-4 average pool. For every pooled cell, the difference between
    the cell and each of its four neighbours (left, right, up, down) is
    computed for both inputs. The loss at that cell is the sum over the four
    directions of the squared mismatch between those differences.

    The loss is symmetric in ``y_true`` / ``y_pred``.
    """

    def __init__(self, **kwargs):
        # Return the unreduced per-location loss map by default; callers
        # reduce it themselves. Other keyword arguments (e.g. ``name``) are
        # forwarded to the base class rather than being silently dropped.
        kwargs.setdefault("reduction", "none")
        super().__init__(**kwargs)
        # Each kernel computes ``center - neighbour`` for one direction
        # (tf.nn.conv2d is a cross-correlation, so no kernel flip applies).
        self.left_kernel = _direction_kernel(
            [[0, 0, 0], [-1, 1, 0], [0, 0, 0]]
        )
        self.right_kernel = _direction_kernel(
            [[0, 0, 0], [0, 1, -1], [0, 0, 0]]
        )
        self.up_kernel = _direction_kernel(
            [[0, -1, 0], [0, 1, 0], [0, 0, 0]]
        )
        self.down_kernel = _direction_kernel(
            [[0, 0, 0], [0, 1, 0], [0, -1, 0]]
        )

    @staticmethod
    def _pooled_mean(images):
        """Average over axis 3 (channels), then 4x4 average-pool, stride 4."""
        channel_mean = tf.reduce_mean(images, 3, keepdims=True)
        # VALID padding drops trailing rows/cols beyond a multiple of 4.
        return tf.nn.avg_pool2d(
            channel_mean, ksize=4, strides=4, padding="VALID"
        )

    @staticmethod
    def _neighbour_difference(pooled, kernel):
        """Apply one direction kernel; SAME padding keeps the spatial size."""
        return tf.nn.conv2d(
            pooled, kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

    def call(self, y_true, y_pred):
        """Return the per-location spatial consistency loss map.

        Args:
            y_true: Image batch, NHWC layout (the channel mean is taken over
                axis 3 and the pooling/convolution use TensorFlow's default
                NHWC format).
            y_pred: Image batch with the same layout and size as ``y_true``.

        Returns:
            A tensor of shape ``(batch, H // 4, W // 4, 1)`` (for H, W >= 4)
            holding the summed squared difference-of-differences over the
            four directions.
        """
        original_pool = self._pooled_mean(y_true)
        enhanced_pool = self._pooled_mean(y_pred)
        kernels = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )
        per_direction = [
            tf.square(
                self._neighbour_difference(original_pool, kernel)
                - self._neighbour_difference(enhanced_pool, kernel)
            )
            for kernel in kernels
        ]
        return tf.add_n(per_direction)