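# NOTE: the next three lines are page-status text captured when this file was
# scraped; they are not part of the module.
# Spaces:
# Runtime error
# Runtime error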
import tensorflow as tf | |
from .spatial_constancy import SpatialConsistencyLoss | |
def color_constancy_loss(x):
    """Penalise disagreement between the per-image means of three channels.

    The input is averaged over axes 1 and 2, then channels 0, 1 and 2 of the
    last axis are compared pairwise.

    Args:
        x: Rank-4 tensor with at least three entries on its last axis;
            presumably ``(batch, height, width, RGB)`` -- TODO confirm
            against the caller.

    Returns:
        Tensor of shape ``(batch, 1, 1)`` holding, per image, the square
        root of the summed pairwise terms.
    """
    # keepdims=True leaves the two averaged axes in place with size 1.
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    red, green, blue = (channel_means[:, :, :, idx] for idx in range(3))

    # Each pairwise difference is squared twice (i.e. raised to the fourth
    # power) before the sum and the square root are taken.
    rg_term = tf.square(tf.square(red - green))
    rb_term = tf.square(tf.square(red - blue))
    gb_term = tf.square(tf.square(blue - green))

    return tf.sqrt(rg_term + rb_term + gb_term)
def exposure_loss(x, mean_val=0.6):
    """Measure how far local average intensity is from a target level.

    Args:
        x: Rank-4 tensor; axis 3 is averaged away first, so it is presumably
            the channel axis of a ``(batch, height, width, channels)`` image
            -- TODO confirm against the caller.
        mean_val: Target value that each pooled patch average is compared to.

    Returns:
        Scalar tensor: the mean squared deviation of the 16x16 patch
        averages from ``mean_val``.
    """
    # Collapse axis 3 to a single entry, keeping the tensor rank-4 for pooling.
    intensity = tf.reduce_mean(x, axis=3, keepdims=True)

    # Non-overlapping 16x16 windows (stride equals window size), VALID padding.
    patch_means = tf.nn.avg_pool2d(
        intensity, ksize=16, strides=16, padding="VALID"
    )

    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))
def illumination_smoothness_loss(x):
    """Return a total-variation style smoothness penalty for a rank-4 tensor.

    Squared differences between neighbouring entries along axis 1 and along
    axis 2 are summed over the whole tensor. Each sum is normalised by the
    number of neighbour pairs in one 2-D slice along that axis, and the
    result is divided by the size of axis 0.

    Args:
        x: Rank-4 float32 tensor; presumably
            ``(batch, height, width, channels)`` -- TODO confirm against
            the caller.

    Returns:
        Scalar float32 tensor
        ``2 * (h_tv / count_h + w_tv / count_w) / batch_size``.
    """
    dims = tf.shape(x)
    batch_size, h_x, w_x = dims[0], dims[1], dims[2]

    # The slices below step along axis 1 (height) and axis 2 (width), so the
    # normalisers must be built from those same two axes. Building them from
    # axes 2 and 3 would count width/channels instead, mis-scaling the loss
    # and giving a zero divisor for single-channel input.
    count_h = tf.cast((h_x - 1) * w_x, dtype=tf.float32)
    count_w = tf.cast(h_x * (w_x - 1), dtype=tf.float32)

    # Sum of squared differences between adjacent rows / adjacent columns.
    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, :-1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, :-1, :]))

    # A spatial axis of size 1 has no neighbour pairs: both the sum and the
    # count are 0, and divide_no_nan turns that 0/0 into 0 instead of NaN.
    h_term = tf.math.divide_no_nan(h_tv, count_h)
    w_term = tf.math.divide_no_nan(w_tv, count_w)

    batch_size = tf.cast(batch_size, dtype=tf.float32)
    return 2 * (h_term + w_term) / batch_size