Spaces:
Runtime error
Runtime error
File size: 2,448 Bytes
659a217 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import tensorflow as tf
from tensorflow.keras import losses
class SpatialConsistencyLoss(losses.Loss):
    """Spatial consistency loss between an original and an enhanced image.

    Both images are reduced to one intensity channel (mean over axis 3),
    average-pooled over non-overlapping 4x4 regions, and for every region the
    difference to its left, right, upper and lower neighbour is taken. The
    loss is the sum over the four directions of the squared difference
    between the original's and the enhanced image's neighbour differences.

    By default the loss is returned unreduced (``reduction="none"``) with
    shape ``(batch, H // 4, W // 4, 1)``; callers are expected to reduce it.
    """

    def __init__(self, **kwargs):
        # Forward extra Loss options (e.g. ``name``) instead of dropping them;
        # keep the unreduced output unless the caller explicitly overrides it.
        kwargs.setdefault("reduction", "none")
        super().__init__(**kwargs)
        # conv2d filters are laid out as (height, width, in_channels,
        # out_channels), so each directional kernel must be (3, 3, 1, 1).
        # The former nested literals had shape (1, 3, 1, 3) -- a 1x3 filter
        # emitting three channels -- which made "left"/"right" compare raw
        # intensities and "up"/"down" compare horizontal differences.
        self.left_kernel = self._make_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = self._make_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = self._make_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = self._make_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])

    @staticmethod
    def _make_kernel(rows):
        """Convert a 3x3 nested list into a (3, 3, 1, 1) float32 conv2d filter."""
        return tf.reshape(tf.constant(rows, dtype=tf.float32), (3, 3, 1, 1))

    @staticmethod
    def _pooled_intensity(images):
        """Average over axis 3 (kept as size 1), then 4x4 average-pool, stride 4."""
        intensity = tf.reduce_mean(images, 3, keepdims=True)
        return tf.nn.avg_pool2d(intensity, ksize=4, strides=4, padding="VALID")

    def call(self, y_true, y_pred):
        """Compute the per-region spatial consistency loss.

        Args:
            y_true: Original image batch. Channels are assumed to be on axis 3
                (NHWC), as implied by the channel mean -- TODO confirm callers.
            y_pred: Enhanced image batch, same shape as ``y_true``.

        Returns:
            Tensor of shape ``(batch, H // 4, W // 4, 1)``.
        """
        original_pool = self._pooled_intensity(y_true)
        enhanced_pool = self._pooled_intensity(y_pred)
        # Zero-padded ("SAME") convolution is linear, so
        # conv(orig, k) - conv(enh, k) == conv(orig - enh, k): one convolution
        # per direction on the difference replaces two.
        difference = original_pool - enhanced_pool
        kernels = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )
        return tf.add_n(
            [
                tf.square(tf.nn.conv2d(difference, kernel, strides=1, padding="SAME"))
                for kernel in kernels
            ]
        )
|