Spaces:
Runtime error
Runtime error
File size: 1,309 Bytes
659a217 0f84baa 659a217 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import tensorflow as tf
from .spatial_constancy import SpatialConsistencyLoss
def color_constancy_loss(x):
    """Penalise differences between the spatially averaged colour channels.

    The input is averaged over axes 1 and 2 (kept as size-1 axes), and the
    first three entries of axis 3 are taken as the red, green and blue means.

    Args:
        x: 4-D tensor indexed as ``[batch, axis1, axis2, channel]`` with at
            least three channels.

    Returns:
        One value per batch element, with the two averaged axes retained as
        size-1 dimensions.
    """
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    red, green, blue = (channel_means[:, :, :, idx] for idx in range(3))

    sq_red_green = tf.square(red - green)
    sq_red_blue = tf.square(red - blue)
    sq_green_blue = tf.square(blue - green)

    # The squared gaps are squared once more before being summed, so the
    # value under the root is a sum of fourth powers of the channel gaps.
    sum_of_powers = (
        tf.square(sq_red_green)
        + tf.square(sq_red_blue)
        + tf.square(sq_green_blue)
    )
    return tf.sqrt(sum_of_powers)
def exposure_loss(x, mean_val=0.6):
    """Mean squared gap between local average intensity and a target level.

    Args:
        x: 4-D tensor whose axis 3 is averaged away to obtain a
            single-channel intensity map.
        mean_val: Target intensity each pooled patch is compared against.

    Returns:
        Scalar tensor: the mean of the squared differences over all patches.
    """
    # Collapse the channel axis but keep it as size 1 so pooling still
    # receives a 4-D input.
    intensity = tf.reduce_mean(x, axis=3, keepdims=True)
    # Non-overlapping 16x16 windows; "VALID" drops any partial window at
    # the border.
    patch_means = tf.nn.avg_pool2d(
        intensity, ksize=16, strides=16, padding="VALID"
    )
    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))
def illumination_smoothness_loss(x):
    """Total-variation style smoothness penalty over the two spatial axes.

    Squared differences between neighbouring entries along axis 1 (height)
    and axis 2 (width) are summed over the whole tensor. Each sum is then
    divided by the number of neighbouring pairs along that axis in a single
    channel of a single image, and the result is averaged over the batch.

    Args:
        x: 4-D tensor indexed as ``[batch, height, width, channel]``.

    Returns:
        Scalar ``float32`` tensor holding the smoothness penalty.
    """
    dims = tf.shape(x)
    batch_size, height, width = dims[0], dims[1], dims[2]

    # Normalise by the spatial pair counts (height/width), not by the
    # width/channel sizes: the latter makes the scale depend on the number
    # of channels and yields a zero divisor for single-channel inputs.
    count_h = (height - 1) * width
    count_w = height * (width - 1)

    # Squared differences between vertically / horizontally adjacent entries.
    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, :-1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, :-1, :]))

    # tf.shape yields integers; cast before the floating-point divisions.
    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
|