import os
import numpy as np
from PIL import Image
from datetime import datetime

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import optimizers, mixed_precision, Model
from wandb.keras import WandbCallback

from .dce_net import build_dce_net
from .dataloader import UnpairedLowLightDataset
from .losses import (
    color_constancy_loss,
    exposure_loss,
    illumination_smoothness_loss,
    SpatialConsistencyLoss,
)
from ..commons import (
    download_lol_dataset,
    download_unpaired_low_light_dataset,
    init_wandb,
)


class ZeroDCE(Model):
    def __init__(
        self,
        experiment_name=None,
        wandb_api_key=None,
        use_mixed_precision: bool = False,
        **kwargs
    ):
        super(ZeroDCE, self).__init__(**kwargs)
        self.experiment_name = experiment_name
        if use_mixed_precision:
            policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(policy)
        if wandb_api_key is not None:
            init_wandb("zero-dce", experiment_name, wandb_api_key)
            self.using_wandb = True
        else:
            self.using_wandb = False
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        super(ZeroDCE, self).compile(**kwargs)
        self.optimizer = optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")

    def get_enhanced_image(self, data, output):
        # The DCE-Net predicts 24 channels: 8 curve-parameter maps of 3 channels
        # each, applied iteratively via the quadratic curve x + r * (x^2 - x).
        r1 = output[:, :, :, :3]
        r2 = output[:, :, :, 3:6]
        r3 = output[:, :, :, 6:9]
        r4 = output[:, :, :, 9:12]
        r5 = output[:, :, :, 12:15]
        r6 = output[:, :, :, 15:18]
        r7 = output[:, :, :, 18:21]
        r8 = output[:, :, :, 21:24]
        x = data + r1 * (tf.square(data) - data)
        x = x + r2 * (tf.square(x) - x)
        x = x + r3 * (tf.square(x) - x)
        enhanced_image = x + r4 * (tf.square(x) - x)
        x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
        x = x + r6 * (tf.square(x) - x)
        x = x + r7 * (tf.square(x) - x)
        enhanced_image = x + r8 * (tf.square(x) - x)
        return enhanced_image

    def call(self, data):
        dce_net_output = self.dce_model(data)
        return self.get_enhanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        """Compute the weighted sum of the four Zero-DCE losses."""
        enhanced_image = self.get_enhanced_image(data, output)
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_constancy_loss(enhanced_image, data)
        )
        loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
        total_loss = (
            loss_illumination
            + loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
        )
        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy_loss": loss_spatial_constancy,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)
        # Only the DCE-Net weights are updated, using the total loss.
        gradients = tape.gradient(
            losses["total_loss"], self.dce_model.trainable_weights
        )
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
        return losses

    def test_step(self, data):
        output = self.dce_model(data)
        return self.compute_losses(data, output)

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """While saving the weights, we simply save the weights of the DCE-Net"""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """While loading the weights, we simply load the weights of the DCE-Net"""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_resize: bool = False,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ) -> None:
        # Zero-DCE trains without paired references, so only the low-light
        # images of the selected dataset are used.
        if dataset_label == "lol":
            (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
        elif dataset_label == "unpaired":
            self.low_images, (
                self.test_low_images,
                _,
            ) = download_unpaired_low_light_dataset()
        data_loader = UnpairedLowLightDataset(
            image_size,
            apply_resize,
            apply_random_horizontal_flip,
            apply_random_vertical_flip,
            apply_random_rotation,
        )
        self.train_dataset, self.val_dataset = data_loader.get_datasets(
            self.low_images, val_split, batch_size
        )

    def train(self, epochs: int):
        log_dir = os.path.join(
            self.experiment_name,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        callbacks = [tensorboard_callback]
        if self.using_wandb:
            callbacks += [WandbCallback()]
        history = self.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(self, original_image):
        # Normalize to [0, 1], add a batch dimension, enhance, and convert the
        # result back to an 8-bit PIL image.
        image = keras.preprocessing.image.img_to_array(original_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)
        output_image = self.call(image)
        output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
        output_image = Image.fromarray(output_image.numpy())
        return output_image

    def infer_from_file(self, original_image_file: str):
        original_image = Image.open(original_image_file)
        return self.infer(original_image)
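

# Usage sketch (illustrative only): a minimal end-to-end run, assuming the
# module is executed inside its package (e.g. ``python -m <package>.zero_dce``,
# both names hypothetical). The experiment name, hyperparameters, and file
# paths below are example assumptions, not part of the library.
if __name__ == "__main__":
    model = ZeroDCE(experiment_name="zero-dce-demo")
    model.compile(learning_rate=1e-4)
    # Download the LOL dataset and build the train/validation pipelines.
    model.build_datasets(dataset_label="lol", image_size=256, batch_size=16)
    model.train(epochs=1)
    # Hypothetical output path for the trained DCE-Net weights.
    model.save_weights("zero_dce_demo_weights.h5")
    # Hypothetical input image; returns an enhanced PIL image.
    enhanced = model.infer_from_file("sample_low_light.png")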