geekyrakshit committed commit 643b5b3 (unverified) · 2 parent(s): dac55f5 a89beff

Merge pull request #1 from soumik12345/mirnet

.gitignore CHANGED
@@ -127,3 +127,8 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# Datasets
+datasets/
+**.zip
+**.h5
README.md CHANGED
@@ -1 +1,11 @@
+---
+title: Enhance Me
+emoji: 🌖
+colorFrom: pink
+colorTo: pink
+sdk: streamlit
+app_file: app.py
+pinned: false
+---
+
 # enhance-me
app.py ADDED
@@ -0,0 +1,40 @@
+from PIL import Image
+import streamlit as st
+from tensorflow.keras import utils
+
+from enhance_me.mirnet import MIRNet
+
+
+@st.cache
+def get_mirnet_object() -> MIRNet:
+    mirnet = MIRNet()
+    mirnet.build_model()
+    utils.get_file(
+        "weights_lol_128.h5",
+        "https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
+        cache_dir=".",
+        cache_subdir="weights",
+    )
+    mirnet.load_weights("./weights/weights_lol_128.h5")
+    return mirnet
+
+
+def main():
+    st.markdown("# Enhance Me")
+    st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
+    application = st.sidebar.selectbox(
+        "Please select the application:", ("", "Low-light enhancement")
+    )
+    if application != "":
+        if application == "Low-light enhancement":
+            uploaded_file = st.sidebar.file_uploader("Select your image:")
+            if uploaded_file is not None:
+                original_image = Image.open(uploaded_file)
+                st.image(original_image, caption="original image")
+                mirnet = get_mirnet_object()
+                enhanced_image = mirnet.infer(original_image)
+                st.image(enhanced_image, caption="enhanced image")
+
+
+if __name__ == "__main__":
+    main()
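
For reference, a minimal sketch of running the same inference path outside Streamlit, using only the `MIRNet` API and release weights referenced in `app.py`; the input and output file names are hypothetical placeholders.

```python
# Sketch (not part of the commit): programmatic inference with the app's weights.
from PIL import Image
from tensorflow.keras import utils

from enhance_me.mirnet import MIRNet

mirnet = MIRNet()
mirnet.build_model()
utils.get_file(
    "weights_lol_128.h5",
    "https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
    cache_dir=".",
    cache_subdir="weights",
)
mirnet.load_weights("./weights/weights_lol_128.h5")

enhanced = mirnet.infer(Image.open("my_low_light_photo.jpg"))  # hypothetical input file
enhanced.save("enhanced.png")  # hypothetical output path
```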
enhance_me/__init__.py ADDED
File without changes
enhance_me/augmentation.py ADDED
@@ -0,0 +1,51 @@
+import tensorflow as tf
+
+
+class AugmentationFactory:
+    def __init__(self, image_size) -> None:
+        self.image_size = image_size
+
+    def random_crop(self, input_image, enhanced_image):
+        input_image_shape = tf.shape(input_image)[:2]
+        low_w = tf.random.uniform(
+            shape=(), maxval=input_image_shape[1] - self.image_size + 1, dtype=tf.int32
+        )
+        low_h = tf.random.uniform(
+            shape=(), maxval=input_image_shape[0] - self.image_size + 1, dtype=tf.int32
+        )
+        enhanced_w = low_w
+        enhanced_h = low_h
+        input_image_cropped = input_image[
+            low_h : low_h + self.image_size, low_w : low_w + self.image_size
+        ]
+        enhanced_image_cropped = enhanced_image[
+            enhanced_h : enhanced_h + self.image_size,
+            enhanced_w : enhanced_w + self.image_size,
+        ]
+        return input_image_cropped, enhanced_image_cropped
+
+    def random_horizontal_flip(self, input_image, enhanced_image):
+        return tf.cond(
+            tf.random.uniform(shape=(), maxval=1) < 0.5,
+            lambda: (input_image, enhanced_image),
+            lambda: (
+                tf.image.flip_left_right(input_image),
+                tf.image.flip_left_right(enhanced_image),
+            ),
+        )
+
+    def random_vertical_flip(self, input_image, enhanced_image):
+        return tf.cond(
+            tf.random.uniform(shape=(), maxval=1) < 0.5,
+            lambda: (input_image, enhanced_image),
+            lambda: (
+                tf.image.flip_up_down(input_image),
+                tf.image.flip_up_down(enhanced_image),
+            ),
+        )
+
+    def random_rotate(self, input_image, enhanced_image):
+        condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
+        return tf.image.rot90(input_image, condition), tf.image.rot90(
+            enhanced_image, condition
+        )
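
Every transform takes the (low-light, enhanced) pair together so both images receive the same crop, flip, or rotation. A small illustrative sketch, with arbitrary stand-in tensor shapes:

```python
# Sketch: exercising AugmentationFactory on one paired example.
import tensorflow as tf

from enhance_me.augmentation import AugmentationFactory

factory = AugmentationFactory(image_size=128)
low = tf.random.uniform((400, 600, 3))       # stand-in for a low-light image
enhanced = tf.random.uniform((400, 600, 3))  # stand-in for its ground truth

low_crop, enhanced_crop = factory.random_crop(low, enhanced)
low_crop, enhanced_crop = factory.random_horizontal_flip(low_crop, enhanced_crop)
low_crop, enhanced_crop = factory.random_rotate(low_crop, enhanced_crop)
print(low_crop.shape, enhanced_crop.shape)  # both (128, 128, 3)
```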
enhance_me/commons.py ADDED
@@ -0,0 +1,63 @@
+import os
+import wandb
+from glob import glob
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+from tensorflow.keras import utils
+
+
+def read_image(image_path):
+    image = tf.io.read_file(image_path)
+    image = tf.image.decode_png(image, channels=3)
+    image.set_shape([None, None, 3])
+    image = tf.cast(image, dtype=tf.float32) / 255.0
+    return image
+
+
+def peak_signal_noise_ratio(y_true, y_pred):
+    return tf.image.psnr(y_pred, y_true, max_val=255.0)
+
+
+def plot_results(images, titles, figure_size=(12, 12)):
+    fig = plt.figure(figsize=figure_size)
+    for i in range(len(images)):
+        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
+        _ = plt.imshow(images[i])
+        plt.axis("off")
+    plt.show()
+
+
+def closest_number(n, m):
+    q = int(n / m)
+    n1 = m * q
+    if (n * m) > 0:
+        n2 = m * (q + 1)
+    else:
+        n2 = m * (q - 1)
+    if abs(n - n1) < abs(n - n2):
+        return n1
+    return n2
+
+
+def init_wandb(project_name, experiment_name, wandb_api_key):
+    if project_name is not None and experiment_name is not None:
+        os.environ["WANDB_API_KEY"] = wandb_api_key
+        wandb.init(project=project_name, name=experiment_name, sync_tensorboard=True)
+
+
+def download_lol_dataset():
+    utils.get_file(
+        "lol_dataset.zip",
+        "https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip",
+        cache_dir="./",
+        cache_subdir="./datasets",
+        extract=True,
+    )
+    low_images = sorted(glob("./datasets/lol_dataset/our485/low/*"))
+    enhanced_images = sorted(glob("./datasets/lol_dataset/our485/high/*"))
+    assert len(low_images) == len(enhanced_images)
+    test_low_images = sorted(glob("./datasets/lol_dataset/eval15/low/*"))
+    test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
+    assert len(test_low_images) == len(test_enhanced_images)
+    return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
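
A quick sketch of the helpers above: `download_lol_dataset` fetches and unpacks the LOL release archive into `./datasets`, and `read_image` decodes PNGs to float32 tensors in [0, 1]. (Note that `peak_signal_noise_ratio` is configured with `max_val=255.0` while images are scaled to [0, 1], so reported PSNR values are offset by a constant and are best read as relative scores.)

```python
# Sketch: downloading the LOL pairs and plotting one example with the helpers above.
from enhance_me.commons import download_lol_dataset, read_image, plot_results

(low_paths, enhanced_paths), (test_low, test_enhanced) = download_lol_dataset()
low = read_image(low_paths[0])            # float32, values in [0, 1]
enhanced = read_image(enhanced_paths[0])
plot_results([low, enhanced], ["low-light", "ground truth"], figure_size=(10, 5))
```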
enhance_me/mirnet/__init__.py ADDED
@@ -0,0 +1 @@
+from .mirnet import MIRNet
enhance_me/mirnet/dataloader.py ADDED
@@ -0,0 +1,92 @@
+import tensorflow as tf
+from typing import List
+
+from ..commons import read_image
+from ..augmentation import AugmentationFactory
+
+
+class LowLightDataset:
+    def __init__(
+        self,
+        image_size: int = 256,
+        apply_random_horizontal_flip: bool = True,
+        apply_random_vertical_flip: bool = True,
+        apply_random_rotation: bool = True,
+    ) -> None:
+        self.augmentation_factory = AugmentationFactory(image_size=image_size)
+        self.apply_random_horizontal_flip = apply_random_horizontal_flip
+        self.apply_random_vertical_flip = apply_random_vertical_flip
+        self.apply_random_rotation = apply_random_rotation
+
+    def load_data(self, low_light_image_path, enhanced_image_path):
+        low_light_image = read_image(low_light_image_path)
+        enhanced_image = read_image(enhanced_image_path)
+        low_light_image, enhanced_image = self.augmentation_factory.random_crop(
+            low_light_image, enhanced_image
+        )
+        return low_light_image, enhanced_image
+
+    def _get_dataset(
+        self,
+        low_light_images: List[str],
+        enhanced_images: List[str],
+        batch_size: int = 16,
+        is_train: bool = True,
+    ):
+        dataset = tf.data.Dataset.from_tensor_slices(
+            (low_light_images, enhanced_images)
+        )
+        dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
+        dataset = dataset.map(
+            self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
+        )
+        if is_train:
+            dataset = (
+                dataset.map(
+                    self.augmentation_factory.random_horizontal_flip,
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                )
+                if self.apply_random_horizontal_flip
+                else dataset
+            )
+            dataset = (
+                dataset.map(
+                    self.augmentation_factory.random_vertical_flip,
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                )
+                if self.apply_random_vertical_flip
+                else dataset
+            )
+            dataset = (
+                dataset.map(
+                    self.augmentation_factory.random_rotate,
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                )
+                if self.apply_random_rotation
+                else dataset
+            )
+        dataset = dataset.batch(batch_size, drop_remainder=True)
+        return dataset
+
+    def get_datasets(
+        self,
+        low_light_images: List[str],
+        enhanced_images: List[str],
+        val_split: float = 0.2,
+        batch_size: int = 16,
+    ):
+        assert len(low_light_images) == len(enhanced_images)
+        split_index = int(len(low_light_images) * (1 - val_split))
+        train_low_light_images = low_light_images[:split_index]
+        train_enhanced_images = enhanced_images[:split_index]
+        val_low_light_images = low_light_images[split_index:]
+        val_enhanced_images = enhanced_images[split_index:]
+        print(f"Number of train data points: {len(train_low_light_images)}")
+        print(f"Number of validation data points: {len(val_low_light_images)}")
+        train_dataset = self._get_dataset(
+            train_low_light_images, train_enhanced_images, batch_size, is_train=True
+        )
+        val_dataset = self._get_dataset(
+            val_low_light_images, val_enhanced_images, batch_size, is_train=False
+        )
+        return train_dataset, val_dataset
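
A minimal sketch of wiring the dataloader to the LOL image paths from `commons.download_lol_dataset`; batch size and crop size here are arbitrary illustration values.

```python
# Sketch: building train/validation pipelines with LowLightDataset.
from enhance_me.commons import download_lol_dataset
from enhance_me.mirnet.dataloader import LowLightDataset

(low_paths, enhanced_paths), _ = download_lol_dataset()
loader = LowLightDataset(image_size=128)
train_ds, val_ds = loader.get_datasets(
    low_paths, enhanced_paths, val_split=0.2, batch_size=4
)
for low_batch, enhanced_batch in train_ds.take(1):
    print(low_batch.shape, enhanced_batch.shape)  # (4, 128, 128, 3) each
```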
enhance_me/mirnet/losses.py ADDED
@@ -0,0 +1,13 @@
+import tensorflow as tf
+from tensorflow.keras import losses
+
+
+class CharbonnierLoss(losses.Loss):
+    def __init__(self, epsilon: float = 1e-3, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.epsilon = epsilon
+
+    def call(self, y_true, y_pred):
+        return tf.reduce_mean(
+            tf.sqrt(tf.square(y_true - y_pred) + tf.square(self.epsilon))
+        )
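
The Charbonnier loss computes the mean of sqrt((y_true − y_pred)² + ε²), a smooth approximation of the L1/MAE loss that stays differentiable at zero error. A tiny sketch with made-up tensors:

```python
# Sketch: for small epsilon, CharbonnierLoss is numerically close to plain MAE.
import tensorflow as tf

from enhance_me.mirnet.losses import CharbonnierLoss

y_true = tf.constant([[0.20, 0.40], [0.60, 0.80]])
y_pred = tf.constant([[0.25, 0.35], [0.55, 0.85]])

charbonnier = CharbonnierLoss(epsilon=1e-3)
print(float(charbonnier(y_true, y_pred)))              # ~0.05001
print(float(tf.reduce_mean(tf.abs(y_true - y_pred))))  # 0.05 (plain MAE)
```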
enhance_me/mirnet/mirnet.py ADDED
@@ -0,0 +1,170 @@
+import os
+import numpy as np
+from PIL import Image
+from typing import List
+from datetime import datetime
+
+from tensorflow import keras
+from tensorflow.keras import optimizers, models
+
+from wandb.keras import WandbCallback
+
+from .dataloader import LowLightDataset
+from .models import build_mirnet_model
+from .losses import CharbonnierLoss
+from ..commons import (
+    peak_signal_noise_ratio,
+    closest_number,
+    init_wandb,
+    download_lol_dataset,
+)
+
+
+class MIRNet:
+    def __init__(self, experiment_name=None, wandb_api_key=None) -> None:
+        self.experiment_name = experiment_name
+        if wandb_api_key is not None:
+            init_wandb("mirnet", experiment_name, wandb_api_key)
+            self.using_wandb = True
+        else:
+            self.using_wandb = False
+
+    def build_datasets(
+        self,
+        image_size: int = 256,
+        dataset_label: str = "lol",
+        apply_random_horizontal_flip: bool = True,
+        apply_random_vertical_flip: bool = True,
+        apply_random_rotation: bool = True,
+        val_split: float = 0.2,
+        batch_size: int = 16,
+    ):
+        if dataset_label == "lol":
+            (self.low_images, self.enhanced_images), (
+                self.test_low_images,
+                self.test_enhanced_images,
+            ) = download_lol_dataset()
+        self.data_loader = LowLightDataset(
+            image_size=image_size,
+            apply_random_horizontal_flip=apply_random_horizontal_flip,
+            apply_random_vertical_flip=apply_random_vertical_flip,
+            apply_random_rotation=apply_random_rotation,
+        )
+        (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
+            low_light_images=self.low_images,
+            enhanced_images=self.enhanced_images,
+            val_split=val_split,
+            batch_size=batch_size,
+        )
+
+    def build_model(
+        self,
+        num_recursive_residual_groups: int = 3,
+        num_multi_scale_residual_blocks: int = 2,
+        channels: int = 64,
+        learning_rate: float = 1e-4,
+        epsilon: float = 1e-3,
+    ):
+        self.model = build_mirnet_model(
+            num_rrg=num_recursive_residual_groups,
+            num_mrb=num_multi_scale_residual_blocks,
+            channels=channels,
+        )
+        self.model.compile(
+            optimizer=optimizers.Adam(learning_rate=learning_rate),
+            loss=CharbonnierLoss(epsilon=epsilon),
+            metrics=[peak_signal_noise_ratio],
+        )
+
+    def load_model(
+        self, filepath, custom_objects=None, compile=True, options=None
+    ) -> None:
+        self.model = models.load_model(
+            filepath=filepath,
+            custom_objects=custom_objects,
+            compile=compile,
+            options=options,
+        )
+
+    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
+        self.model.save_weights(
+            filepath, overwrite=overwrite, save_format=save_format, options=options
+        )
+
+    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
+        self.model.load_weights(
+            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
+        )
+
+    def train(self, epochs: int):
+        log_dir = os.path.join(
+            self.experiment_name,
+            "logs",
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+        )
+        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
+        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
+            os.path.join(self.experiment_name, "weights.h5"),
+            save_best_only=True,
+            save_weights_only=True,
+        )
+        reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
+            monitor="val_peak_signal_noise_ratio",
+            factor=0.5,
+            patience=5,
+            verbose=1,
+            min_delta=1e-7,
+            mode="max",
+        )
+        callbacks = [
+            tensorboard_callback,
+            model_checkpoint_callback,
+            reduce_lr_callback,
+        ]
+        if self.using_wandb:
+            callbacks += [WandbCallback()]
+        history = self.model.fit(
+            self.train_dataset,
+            validation_data=self.val_dataset,
+            epochs=epochs,
+            callbacks=callbacks,
+        )
+        return history
+
+    def infer(
+        self,
+        original_image,
+        image_resize_factor: float = 1.0,
+        resize_output: bool = False,
+    ):
+        width, height = original_image.size
+        target_width, target_height = (
+            closest_number(width // image_resize_factor, 4),
+            closest_number(height // image_resize_factor, 4),
+        )
+        original_image = original_image.resize(
+            (target_width, target_height), Image.ANTIALIAS
+        )
+        image = keras.preprocessing.image.img_to_array(original_image)
+        image = image.astype("float32") / 255.0
+        image = np.expand_dims(image, axis=0)
+        output = self.model.predict(image)
+        output_image = output[0] * 255.0
+        output_image = output_image.clip(0, 255)
+        output_image = output_image.reshape(
+            (np.shape(output_image)[0], np.shape(output_image)[1], 3)
+        )
+        output_image = Image.fromarray(np.uint8(output_image))
+        original_image = Image.fromarray(np.uint8(original_image))
+        if resize_output:
+            output_image = output_image.resize((width, height), Image.ANTIALIAS)
+        return output_image
+
+    def infer_from_file(
+        self,
+        original_image_file: str,
+        image_resize_factor: float = 1.0,
+        resize_output: bool = False,
+    ):
+        original_image = Image.open(original_image_file)
+        return self.infer(original_image, image_resize_factor, resize_output)
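
A condensed sketch of the intended end-to-end flow with this wrapper, mirroring the Colab notebook added further down in this commit; the experiment name and hyperparameter values are illustrative.

```python
# Sketch: dataset -> model -> train -> infer with the MIRNet wrapper above.
from enhance_me.mirnet import MIRNet

mirnet = MIRNet(experiment_name="lol_128")  # name is arbitrary; no wandb key given
mirnet.build_datasets(image_size=128, dataset_label="lol", batch_size=4)
mirnet.build_model(
    num_recursive_residual_groups=3,
    num_multi_scale_residual_blocks=2,
    learning_rate=1e-4,
)
history = mirnet.train(epochs=10)
enhanced = mirnet.infer_from_file(mirnet.test_low_images[0])
```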
enhance_me/mirnet/models/__init__.py ADDED
@@ -0,0 +1 @@
+from .mirnet_model import build_mirnet_model
enhance_me/mirnet/models/dual_attention.py ADDED
@@ -0,0 +1,41 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+def spatial_attention_block(input_tensor):
+    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
+    average_pooling = tf.expand_dims(average_pooling, axis=-1)
+    max_pooling = tf.reduce_max(input_tensor, axis=-1)
+    max_pooling = tf.expand_dims(max_pooling, axis=-1)
+    concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
+    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
+    feature_map = tf.nn.sigmoid(feature_map)
+    return input_tensor * feature_map
+
+
+def channel_attention_block(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
+    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
+    feature_activations = layers.Conv2D(
+        filters=channels // 8, kernel_size=(1, 1), activation="relu"
+    )(feature_descriptor)
+    feature_activations = layers.Conv2D(
+        filters=channels, kernel_size=(1, 1), activation="sigmoid"
+    )(feature_activations)
+    return input_tensor * feature_activations
+
+
+def dual_attention_unit_block(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    feature_map = layers.Conv2D(
+        channels, kernel_size=(3, 3), padding="same", activation="relu"
+    )(input_tensor)
+    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
+        feature_map
+    )
+    channel_attention = channel_attention_block(feature_map)
+    spatial_attention = spatial_attention_block(feature_map)
+    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
+    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
+    return layers.Add()([input_tensor, concatenation])
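
Because the dual attention unit ends with a residual addition, its output shape matches its input shape, which is what lets the MRB chain it at every scale. A shape-check sketch (assuming a TF 2.x functional graph with 64 channels):

```python
# Sketch: the DAU block is shape-preserving.
from tensorflow.keras import Input, Model

from enhance_me.mirnet.models.dual_attention import dual_attention_unit_block

inputs = Input(shape=(None, None, 64))
outputs = dual_attention_unit_block(inputs)
print(Model(inputs, outputs).output_shape)  # (None, None, None, 64)
```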
enhance_me/mirnet/models/mirnet_model.py ADDED
@@ -0,0 +1,13 @@
+from tensorflow.keras import layers, Input, Model
+
+from .recursive_residual_blocks import recursive_residual_group
+
+
+def build_mirnet_model(num_rrg, num_mrb, channels):
+    input_tensor = Input(shape=[None, None, 3])
+    x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
+    for _ in range(num_rrg):
+        x1 = recursive_residual_group(x1, num_mrb, channels)
+    conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
+    output_tensor = layers.Add()([input_tensor, conv])
+    return Model(input_tensor, output_tensor)
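
The model predicts a residual that is added back to the input, so an RGB image maps to an RGB image of the same size. A quick sketch with the defaults used by `MIRNet.build_model()` (spatial dimensions should be divisible by 4 because of the two down-sampling steps inside each MRB):

```python
# Sketch: building the full MIRNet graph and checking the output shape.
import tensorflow as tf

from enhance_me.mirnet.models import build_mirnet_model

model = build_mirnet_model(num_rrg=3, num_mrb=2, channels=64)
dummy = tf.random.uniform((1, 128, 128, 3))
print(model(dummy).shape)  # (1, 128, 128, 3)
```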
enhance_me/mirnet/models/recursive_residual_blocks.py ADDED
@@ -0,0 +1,79 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from .skff import selective_kernel_feature_fusion
+from .dual_attention import dual_attention_unit_block
+
+
+def down_sampling_module(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
+        input_tensor
+    )
+    main_branch = layers.Conv2D(
+        channels, kernel_size=(3, 3), padding="same", activation="relu"
+    )(main_branch)
+    main_branch = layers.MaxPooling2D()(main_branch)
+    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
+    skip_branch = layers.MaxPooling2D()(input_tensor)
+    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
+    return layers.Add()([skip_branch, main_branch])
+
+
+def up_sampling_module(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
+        input_tensor
+    )
+    main_branch = layers.Conv2D(
+        channels, kernel_size=(3, 3), padding="same", activation="relu"
+    )(main_branch)
+    main_branch = layers.UpSampling2D()(main_branch)
+    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
+    skip_branch = layers.UpSampling2D()(input_tensor)
+    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
+    return layers.Add()([skip_branch, main_branch])
+
+
+# MRB Block
+def multi_scale_residual_block(input_tensor, channels):
+    # features
+    level1 = input_tensor
+    level2 = down_sampling_module(input_tensor)
+    level3 = down_sampling_module(level2)
+    # DAU
+    level1_dau = dual_attention_unit_block(level1)
+    level2_dau = dual_attention_unit_block(level2)
+    level3_dau = dual_attention_unit_block(level3)
+    # SKFF
+    level1_skff = selective_kernel_feature_fusion(
+        level1_dau,
+        up_sampling_module(level2_dau),
+        up_sampling_module(up_sampling_module(level3_dau)),
+    )
+    level2_skff = selective_kernel_feature_fusion(
+        down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
+    )
+    level3_skff = selective_kernel_feature_fusion(
+        down_sampling_module(down_sampling_module(level1_dau)),
+        down_sampling_module(level2_dau),
+        level3_dau,
+    )
+    # DAU 2
+    level1_dau_2 = dual_attention_unit_block(level1_skff)
+    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
+    level3_dau_2 = up_sampling_module(
+        up_sampling_module(dual_attention_unit_block(level3_skff))
+    )
+    # SKFF 2
+    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
+    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
+    return layers.Add()([input_tensor, conv])
+
+
+def recursive_residual_group(input_tensor, num_mrb, channels):
+    conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
+    for _ in range(num_mrb):
+        conv1 = multi_scale_residual_block(conv1, channels)
+    conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
+    return layers.Add()([conv2, input_tensor])
enhance_me/mirnet/models/skff.py ADDED
@@ -0,0 +1,30 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+def selective_kernel_feature_fusion(
+    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
+):
+    channels = list(multi_scale_feature_1.shape)[-1]
+    combined_feature = layers.Add()(
+        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
+    )
+    gap = layers.GlobalAveragePooling2D()(combined_feature)
+    channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
+    compact_feature_representation = layers.Conv2D(
+        filters=channels // 8, kernel_size=(1, 1), activation="relu"
+    )(channel_wise_statistics)
+    feature_descriptor_1 = layers.Conv2D(
+        channels, kernel_size=(1, 1), activation="softmax"
+    )(compact_feature_representation)
+    feature_descriptor_2 = layers.Conv2D(
+        channels, kernel_size=(1, 1), activation="softmax"
+    )(compact_feature_representation)
+    feature_descriptor_3 = layers.Conv2D(
+        channels, kernel_size=(1, 1), activation="softmax"
+    )(compact_feature_representation)
+    feature_1 = multi_scale_feature_1 * feature_descriptor_1
+    feature_2 = multi_scale_feature_2 * feature_descriptor_2
+    feature_3 = multi_scale_feature_3 * feature_descriptor_3
+    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
+    return aggregated_feature
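
SKFF expects three feature maps with identical shapes and returns a fused map of that same shape, weighting each branch by learned channel descriptors. A small shape-check sketch (again assuming a TF 2.x functional graph):

```python
# Sketch: SKFF fuses three same-shaped feature maps into one of the same shape.
from tensorflow.keras import Input, Model

from enhance_me.mirnet.models.skff import selective_kernel_feature_fusion

a = Input(shape=(None, None, 64))
b = Input(shape=(None, None, 64))
c = Input(shape=(None, None, 64))
fused = selective_kernel_feature_fusion(a, b, c)
print(Model([a, b, c], fused).output_shape)  # (None, None, None, 64)
```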
notebooks/.gitkeep ADDED
File without changes
notebooks/enhance_me_train.ipynb ADDED
@@ -0,0 +1,207 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "view-in-github"
+   },
+   "source": [
+    "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "1JryaVhtBHij",
+    "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
+   },
+   "outputs": [],
+   "source": [
+    "!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
+    "!pip install -qqq wandb streamlit"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "G_c4VtXWHR5l"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "sys.path.append(\"./enhance-me\")\n",
+    "\n",
+    "from PIL import Image\n",
+    "from enhance_me import commons\n",
+    "from enhance_me.mirnet import MIRNet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "ZpBHbYaMIqP_"
+   },
+   "outputs": [],
+   "source": [
+    "# @title MIRNet Train Configs\n",
+    "\n",
+    "experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
+    "image_size = 128 # @param {type:\"integer\"}\n",
+    "dataset_label = \"lol\" # @param [\"lol\"]\n",
+    "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
+    "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
+    "apply_random_rotation = True # @param {type:\"boolean\"}\n",
+    "wandb_api_key = \"\" # @param {type:\"string\"}\n",
+    "val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
+    "batch_size = 4 # @param {type:\"integer\"}\n",
+    "num_recursive_residual_groups = 3 # @param {type:\"slider\", min:1, max:5, step:1}\n",
+    "num_multi_scale_residual_blocks = 2 # @param {type:\"slider\", min:1, max:5, step:1}\n",
+    "learning_rate = 1e-4 # @param {type:\"number\"}\n",
+    "epsilon = 1e-3 # @param {type:\"number\"}\n",
+    "epochs = 50 # @param {type:\"slider\", min:10, max:100, step:5}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 52
+    },
+    "id": "IVRoedqBIMuH",
+    "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
+   },
+   "outputs": [],
+   "source": [
+    "mirnet = MIRNet(\n",
+    "    experiment_name=experiment_name,\n",
+    "    wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "O66Iwzx8IsGh",
+    "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
+   },
+   "outputs": [],
+   "source": [
+    "mirnet.build_datasets(\n",
+    "    image_size=image_size,\n",
+    "    dataset_label=dataset_label,\n",
+    "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
+    "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
+    "    apply_random_rotation=apply_random_rotation,\n",
+    "    val_split=val_split,\n",
+    "    batch_size=batch_size,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "tsfKrBCsL_Bb"
+   },
+   "outputs": [],
+   "source": [
+    "mirnet.build_model(\n",
+    "    num_recursive_residual_groups=num_recursive_residual_groups,\n",
+    "    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
+    "    learning_rate=learning_rate,\n",
+    "    epsilon=epsilon,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "y3L9wlpkNziL",
+    "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
+   },
+   "outputs": [],
+   "source": [
+    "history = mirnet.train(epochs=epochs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mirnet.load_weights(os.path.join(mirnet.experiment_name, \"weights.h5\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "background_save": true
+    },
+    "id": "daFKbgBkiyzc"
+   },
+   "outputs": [],
+   "source": [
+    "for index, low_image_file in enumerate(mirnet.test_low_images):\n",
+    "    original_image = Image.open(low_image_file)\n",
+    "    enhanced_image = mirnet.infer(original_image)\n",
+    "    ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
+    "    commons.plot_results(\n",
+    "        [original_image, ground_truth, enhanced_image],\n",
+    "        [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
+    "        (18, 18),\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "dO-IbNQHkB3R"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
+   "collapsed_sections": [],
+   "include_colab_link": true,
+   "machine_shape": "hm",
+   "name": "enhance-me-train.ipynb",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
requirements.txt ADDED
@@ -0,0 +1,7 @@
+black
+gdown
+matplotlib
+streamlit
+tensorflow
+tqdm
+wandb