Spaces:
Runtime error
Runtime error
Merge pull request #1 from soumik12345/mirnet
Browse files- .gitignore +5 -0
- README.md +10 -0
- app.py +40 -0
- enhance_me/__init__.py +0 -0
- enhance_me/augmentation.py +51 -0
- enhance_me/commons.py +63 -0
- enhance_me/mirnet/__init__.py +1 -0
- enhance_me/mirnet/dataloader.py +92 -0
- enhance_me/mirnet/losses.py +13 -0
- enhance_me/mirnet/mirnet.py +170 -0
- enhance_me/mirnet/models/__init__.py +1 -0
- enhance_me/mirnet/models/dual_attention.py +41 -0
- enhance_me/mirnet/models/mirnet_model.py +13 -0
- enhance_me/mirnet/models/recursive_residual_blocks.py +79 -0
- enhance_me/mirnet/models/skff.py +30 -0
- notebooks/.gitkeep +0 -0
- notebooks/enhance_me_train.ipynb +207 -0
- requirements.txt +7 -0
.gitignore
CHANGED
@@ -127,3 +127,8 @@ dmypy.json
|
|
127 |
|
128 |
# Pyre type checker
|
129 |
.pyre/
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
# Pyre type checker
|
129 |
.pyre/
|
130 |
+
|
131 |
+
# Datasets
|
132 |
+
datasets/
|
133 |
+
**.zip
|
134 |
+
**.h5
|
README.md
CHANGED
@@ -1 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# enhance-me
|
|
|
1 |
+
---
|
2 |
+
title: Enhance Me
|
3 |
+
emoji: 🌖
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: pink
|
6 |
+
sdk: streamlit
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
---
|
10 |
+
|
11 |
# enhance-me
|
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import streamlit as st
|
3 |
+
from tensorflow.keras import utils
|
4 |
+
|
5 |
+
from enhance_me.mirnet import MIRNet
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache
|
9 |
+
def get_mirnet_object() -> MIRNet:
|
10 |
+
mirnet = MIRNet()
|
11 |
+
mirnet.build_model()
|
12 |
+
utils.get_file(
|
13 |
+
"weights_lol_128.h5",
|
14 |
+
"https://github.com/soumik12345/enhance-me/releases/download/v0.2/weights_lol_128.h5",
|
15 |
+
cache_dir=".",
|
16 |
+
cache_subdir="weights",
|
17 |
+
)
|
18 |
+
mirnet.load_weights("./weights/weights_lol_128.h5")
|
19 |
+
return mirnet
|
20 |
+
|
21 |
+
|
22 |
+
def main():
|
23 |
+
st.markdown("# Enhance Me")
|
24 |
+
st.markdown("Made with :heart: by [geekyRakshit](http://github.com/soumik12345)")
|
25 |
+
application = st.sidebar.selectbox(
|
26 |
+
"Please select the application:", ("", "Low-light enhancement")
|
27 |
+
)
|
28 |
+
if application != "":
|
29 |
+
if application == "Low-light enhancement":
|
30 |
+
uploaded_file = st.sidebar.file_uploader("Select your image:")
|
31 |
+
if uploaded_file is not None:
|
32 |
+
original_image = Image.open(uploaded_file)
|
33 |
+
st.image(original_image, caption="original image")
|
34 |
+
mirnet = get_mirnet_object()
|
35 |
+
enhanced_image = mirnet.infer(original_image)
|
36 |
+
st.image(enhanced_image, caption="enhanced image")
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
main()
|
enhance_me/__init__.py
ADDED
File without changes
|
enhance_me/augmentation.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
|
4 |
+
class AugmentationFactory:
|
5 |
+
def __init__(self, image_size) -> None:
|
6 |
+
self.image_size = image_size
|
7 |
+
|
8 |
+
def random_crop(self, input_image, enhanced_image):
|
9 |
+
input_image_shape = tf.shape(input_image)[:2]
|
10 |
+
low_w = tf.random.uniform(
|
11 |
+
shape=(), maxval=input_image_shape[1] - self.image_size + 1, dtype=tf.int32
|
12 |
+
)
|
13 |
+
low_h = tf.random.uniform(
|
14 |
+
shape=(), maxval=input_image_shape[0] - self.image_size + 1, dtype=tf.int32
|
15 |
+
)
|
16 |
+
enhanced_w = low_w
|
17 |
+
enhanced_h = low_h
|
18 |
+
input_image_cropped = input_image[
|
19 |
+
low_h : low_h + self.image_size, low_w : low_w + self.image_size
|
20 |
+
]
|
21 |
+
enhanced_image_cropped = enhanced_image[
|
22 |
+
enhanced_h : enhanced_h + self.image_size,
|
23 |
+
enhanced_w : enhanced_w + self.image_size,
|
24 |
+
]
|
25 |
+
return input_image_cropped, enhanced_image_cropped
|
26 |
+
|
27 |
+
def random_horizontal_flip(sefl, input_image, enhanced_image):
|
28 |
+
return tf.cond(
|
29 |
+
tf.random.uniform(shape=(), maxval=1) < 0.5,
|
30 |
+
lambda: (input_image, enhanced_image),
|
31 |
+
lambda: (
|
32 |
+
tf.image.flip_left_right(input_image),
|
33 |
+
tf.image.flip_left_right(enhanced_image),
|
34 |
+
),
|
35 |
+
)
|
36 |
+
|
37 |
+
def random_vertical_flip(self, input_image, enhanced_image):
|
38 |
+
return tf.cond(
|
39 |
+
tf.random.uniform(shape=(), maxval=1) < 0.5,
|
40 |
+
lambda: (input_image, enhanced_image),
|
41 |
+
lambda: (
|
42 |
+
tf.image.flip_up_down(input_image),
|
43 |
+
tf.image.flip_up_down(enhanced_image),
|
44 |
+
),
|
45 |
+
)
|
46 |
+
|
47 |
+
def random_rotate(self, input_image, enhanced_image):
|
48 |
+
condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
|
49 |
+
return tf.image.rot90(input_image, condition), tf.image.rot90(
|
50 |
+
enhanced_image, condition
|
51 |
+
)
|
enhance_me/commons.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import wandb
|
3 |
+
from glob import glob
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
import tensorflow as tf
|
7 |
+
from tensorflow.keras import utils
|
8 |
+
|
9 |
+
|
10 |
+
def read_image(image_path):
|
11 |
+
image = tf.io.read_file(image_path)
|
12 |
+
image = tf.image.decode_png(image, channels=3)
|
13 |
+
image.set_shape([None, None, 3])
|
14 |
+
image = tf.cast(image, dtype=tf.float32) / 255.0
|
15 |
+
return image
|
16 |
+
|
17 |
+
|
18 |
+
def peak_signal_noise_ratio(y_true, y_pred):
|
19 |
+
return tf.image.psnr(y_pred, y_true, max_val=255.0)
|
20 |
+
|
21 |
+
|
22 |
+
def plot_results(images, titles, figure_size=(12, 12)):
|
23 |
+
fig = plt.figure(figsize=figure_size)
|
24 |
+
for i in range(len(images)):
|
25 |
+
fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
|
26 |
+
_ = plt.imshow(images[i])
|
27 |
+
plt.axis("off")
|
28 |
+
plt.show()
|
29 |
+
|
30 |
+
|
31 |
+
def closest_number(n, m):
|
32 |
+
q = int(n / m)
|
33 |
+
n1 = m * q
|
34 |
+
if (n * m) > 0:
|
35 |
+
n2 = m * (q + 1)
|
36 |
+
else:
|
37 |
+
n2 = m * (q - 1)
|
38 |
+
if abs(n - n1) < abs(n - n2):
|
39 |
+
return n1
|
40 |
+
return n2
|
41 |
+
|
42 |
+
|
43 |
+
def init_wandb(project_name, experiment_name, wandb_api_key):
|
44 |
+
if project_name is not None and experiment_name is not None:
|
45 |
+
os.environ["WANDB_API_KEY"] = wandb_api_key
|
46 |
+
wandb.init(project=project_name, name=experiment_name, sync_tensorboard=True)
|
47 |
+
|
48 |
+
|
49 |
+
def download_lol_dataset():
|
50 |
+
utils.get_file(
|
51 |
+
"lol_dataset.zip",
|
52 |
+
"https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip",
|
53 |
+
cache_dir="./",
|
54 |
+
cache_subdir="./datasets",
|
55 |
+
extract=True,
|
56 |
+
)
|
57 |
+
low_images = sorted(glob("./datasets/lol_dataset/our485/low/*"))
|
58 |
+
enhanced_images = sorted(glob("./datasets/lol_dataset/our485/high/*"))
|
59 |
+
assert len(low_images) == len(enhanced_images)
|
60 |
+
test_low_images = sorted(glob("./datasets/lol_dataset/eval15/low/*"))
|
61 |
+
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
62 |
+
assert len(test_low_images) == len(test_enhanced_images)
|
63 |
+
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
enhance_me/mirnet/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .mirnet import MIRNet
|
enhance_me/mirnet/dataloader.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from ..commons import read_image
|
5 |
+
from ..augmentation import AugmentationFactory
|
6 |
+
|
7 |
+
|
8 |
+
class LowLightDataset:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
image_size: int = 256,
|
12 |
+
apply_random_horizontal_flip: bool = True,
|
13 |
+
apply_random_vertical_flip: bool = True,
|
14 |
+
apply_random_rotation: bool = True,
|
15 |
+
) -> None:
|
16 |
+
self.augmentation_factory = AugmentationFactory(image_size=image_size)
|
17 |
+
self.apply_random_horizontal_flip = apply_random_horizontal_flip
|
18 |
+
self.apply_random_vertical_flip = apply_random_vertical_flip
|
19 |
+
self.apply_random_rotation = apply_random_rotation
|
20 |
+
|
21 |
+
def load_data(self, low_light_image_path, enhanced_image_path):
|
22 |
+
low_light_image = read_image(low_light_image_path)
|
23 |
+
enhanced_image = read_image(enhanced_image_path)
|
24 |
+
low_light_image, enhanced_image = self.augmentation_factory.random_crop(
|
25 |
+
low_light_image, enhanced_image
|
26 |
+
)
|
27 |
+
return low_light_image, enhanced_image
|
28 |
+
|
29 |
+
def _get_dataset(
|
30 |
+
self,
|
31 |
+
low_light_images: List[str],
|
32 |
+
enhanced_images: List[str],
|
33 |
+
batch_size: int = 16,
|
34 |
+
is_train: bool = True,
|
35 |
+
):
|
36 |
+
dataset = tf.data.Dataset.from_tensor_slices(
|
37 |
+
(low_light_images, enhanced_images)
|
38 |
+
)
|
39 |
+
dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
|
40 |
+
dataset = dataset.map(
|
41 |
+
self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
|
42 |
+
)
|
43 |
+
if is_train:
|
44 |
+
dataset = (
|
45 |
+
dataset.map(
|
46 |
+
self.augmentation_factory.random_horizontal_flip,
|
47 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
48 |
+
)
|
49 |
+
if self.apply_random_horizontal_flip
|
50 |
+
else dataset
|
51 |
+
)
|
52 |
+
dataset = (
|
53 |
+
dataset.map(
|
54 |
+
self.augmentation_factory.random_vertical_flip,
|
55 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
56 |
+
)
|
57 |
+
if self.apply_random_vertical_flip
|
58 |
+
else dataset
|
59 |
+
)
|
60 |
+
dataset = (
|
61 |
+
dataset.map(
|
62 |
+
self.augmentation_factory.random_rotate,
|
63 |
+
num_parallel_calls=tf.data.AUTOTUNE,
|
64 |
+
)
|
65 |
+
if self.apply_random_rotation
|
66 |
+
else dataset
|
67 |
+
)
|
68 |
+
dataset = dataset.batch(batch_size, drop_remainder=True)
|
69 |
+
return dataset
|
70 |
+
|
71 |
+
def get_datasets(
|
72 |
+
self,
|
73 |
+
low_light_images: List[str],
|
74 |
+
enhanced_images: List[str],
|
75 |
+
val_split: float = 0.2,
|
76 |
+
batch_size: int = 16,
|
77 |
+
):
|
78 |
+
assert len(low_light_images) == len(enhanced_images)
|
79 |
+
split_index = int(len(low_light_images) * (1 - val_split))
|
80 |
+
train_low_light_images = low_light_images[:split_index]
|
81 |
+
train_enhanced_images = enhanced_images[:split_index]
|
82 |
+
val_low_light_images = low_light_images[split_index:]
|
83 |
+
val_enhanced_images = enhanced_images[split_index:]
|
84 |
+
print(f"Number of train data points: {len(train_low_light_images)}")
|
85 |
+
print(f"Number of validation data points: {len(val_low_light_images)}")
|
86 |
+
train_dataset = self._get_dataset(
|
87 |
+
train_low_light_images, train_enhanced_images, batch_size, is_train=True
|
88 |
+
)
|
89 |
+
val_dataset = self._get_dataset(
|
90 |
+
val_low_light_images, val_enhanced_images, batch_size, is_train=False
|
91 |
+
)
|
92 |
+
return train_dataset, val_dataset
|
enhance_me/mirnet/losses.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import losses
|
3 |
+
|
4 |
+
|
5 |
+
class CharbonnierLoss(losses.Loss):
|
6 |
+
def __init__(self, epsilon: float = 1e-3, *args, **kwargs):
|
7 |
+
super().__init__(*args, **kwargs)
|
8 |
+
self.epsilon = epsilon
|
9 |
+
|
10 |
+
def call(self, y_true, y_pred):
|
11 |
+
return tf.reduce_mean(
|
12 |
+
tf.sqrt(tf.square(y_true - y_pred) + tf.square(self.epsilon))
|
13 |
+
)
|
enhance_me/mirnet/mirnet.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from typing import List
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
from tensorflow import keras
|
8 |
+
from tensorflow.keras import optimizers, models
|
9 |
+
|
10 |
+
from wandb.keras import WandbCallback
|
11 |
+
|
12 |
+
from .dataloader import LowLightDataset
|
13 |
+
from .models import build_mirnet_model
|
14 |
+
from .losses import CharbonnierLoss
|
15 |
+
from ..commons import (
|
16 |
+
peak_signal_noise_ratio,
|
17 |
+
closest_number,
|
18 |
+
init_wandb,
|
19 |
+
download_lol_dataset,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class MIRNet:
|
24 |
+
def __init__(self, experiment_name=None, wandb_api_key=None) -> None:
|
25 |
+
self.experiment_name = experiment_name
|
26 |
+
if wandb_api_key is not None:
|
27 |
+
init_wandb("mirnet", experiment_name, wandb_api_key)
|
28 |
+
self.using_wandb = True
|
29 |
+
else:
|
30 |
+
self.using_wandb = False
|
31 |
+
|
32 |
+
def build_datasets(
|
33 |
+
self,
|
34 |
+
image_size: int = 256,
|
35 |
+
dataset_label: str = "lol",
|
36 |
+
apply_random_horizontal_flip: bool = True,
|
37 |
+
apply_random_vertical_flip: bool = True,
|
38 |
+
apply_random_rotation: bool = True,
|
39 |
+
val_split: float = 0.2,
|
40 |
+
batch_size: int = 16,
|
41 |
+
):
|
42 |
+
if dataset_label == "lol":
|
43 |
+
(self.low_images, self.enhanced_images), (
|
44 |
+
self.test_low_images,
|
45 |
+
self.test_enhanced_images,
|
46 |
+
) = download_lol_dataset()
|
47 |
+
self.data_loader = LowLightDataset(
|
48 |
+
image_size=image_size,
|
49 |
+
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
50 |
+
apply_random_vertical_flip=apply_random_vertical_flip,
|
51 |
+
apply_random_rotation=apply_random_rotation,
|
52 |
+
)
|
53 |
+
(self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
|
54 |
+
low_light_images=self.low_images,
|
55 |
+
enhanced_images=self.enhanced_images,
|
56 |
+
val_split=val_split,
|
57 |
+
batch_size=batch_size,
|
58 |
+
)
|
59 |
+
|
60 |
+
def build_model(
|
61 |
+
self,
|
62 |
+
num_recursive_residual_groups: int = 3,
|
63 |
+
num_multi_scale_residual_blocks: int = 2,
|
64 |
+
channels: int = 64,
|
65 |
+
learning_rate: float = 1e-4,
|
66 |
+
epsilon: float = 1e-3,
|
67 |
+
):
|
68 |
+
self.model = build_mirnet_model(
|
69 |
+
num_rrg=num_recursive_residual_groups,
|
70 |
+
num_mrb=num_multi_scale_residual_blocks,
|
71 |
+
channels=channels,
|
72 |
+
)
|
73 |
+
self.model.compile(
|
74 |
+
optimizer=optimizers.Adam(learning_rate=learning_rate),
|
75 |
+
loss=CharbonnierLoss(epsilon=epsilon),
|
76 |
+
metrics=[peak_signal_noise_ratio],
|
77 |
+
)
|
78 |
+
|
79 |
+
def load_model(
|
80 |
+
self, filepath, custom_objects=None, compile=True, options=None
|
81 |
+
) -> None:
|
82 |
+
self.model = models.load_model(
|
83 |
+
filepath=filepath,
|
84 |
+
custom_objects=custom_objects,
|
85 |
+
compile=compile,
|
86 |
+
options=options,
|
87 |
+
)
|
88 |
+
|
89 |
+
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
90 |
+
self.model.save_weights(
|
91 |
+
filepath, overwrite=overwrite, save_format=save_format, options=options
|
92 |
+
)
|
93 |
+
|
94 |
+
def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
|
95 |
+
self.model.load_weights(
|
96 |
+
filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
|
97 |
+
)
|
98 |
+
|
99 |
+
def train(self, epochs: int):
|
100 |
+
log_dir = os.path.join(
|
101 |
+
self.experiment_name,
|
102 |
+
"logs",
|
103 |
+
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
104 |
+
)
|
105 |
+
tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
|
106 |
+
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
|
107 |
+
os.path.join(self.experiment_name, "weights.h5"),
|
108 |
+
save_best_only=True,
|
109 |
+
save_weights_only=True,
|
110 |
+
)
|
111 |
+
reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
|
112 |
+
monitor="val_peak_signal_noise_ratio",
|
113 |
+
factor=0.5,
|
114 |
+
patience=5,
|
115 |
+
verbose=1,
|
116 |
+
min_delta=1e-7,
|
117 |
+
mode="max",
|
118 |
+
)
|
119 |
+
callbacks = [
|
120 |
+
tensorboard_callback,
|
121 |
+
model_checkpoint_callback,
|
122 |
+
reduce_lr_callback,
|
123 |
+
]
|
124 |
+
if self.using_wandb:
|
125 |
+
callbacks += [WandbCallback()]
|
126 |
+
history = self.model.fit(
|
127 |
+
self.train_dataset,
|
128 |
+
validation_data=self.val_dataset,
|
129 |
+
epochs=epochs,
|
130 |
+
callbacks=callbacks,
|
131 |
+
)
|
132 |
+
return history
|
133 |
+
|
134 |
+
def infer(
|
135 |
+
self,
|
136 |
+
original_image,
|
137 |
+
image_resize_factor: float = 1.0,
|
138 |
+
resize_output: bool = False,
|
139 |
+
):
|
140 |
+
width, height = original_image.size
|
141 |
+
target_width, target_height = (
|
142 |
+
closest_number(width // image_resize_factor, 4),
|
143 |
+
closest_number(height // image_resize_factor, 4),
|
144 |
+
)
|
145 |
+
original_image = original_image.resize(
|
146 |
+
(target_width, target_height), Image.ANTIALIAS
|
147 |
+
)
|
148 |
+
image = keras.preprocessing.image.img_to_array(original_image)
|
149 |
+
image = image.astype("float32") / 255.0
|
150 |
+
image = np.expand_dims(image, axis=0)
|
151 |
+
output = self.model.predict(image)
|
152 |
+
output_image = output[0] * 255.0
|
153 |
+
output_image = output_image.clip(0, 255)
|
154 |
+
output_image = output_image.reshape(
|
155 |
+
(np.shape(output_image)[0], np.shape(output_image)[1], 3)
|
156 |
+
)
|
157 |
+
output_image = Image.fromarray(np.uint8(output_image))
|
158 |
+
original_image = Image.fromarray(np.uint8(original_image))
|
159 |
+
if resize_output:
|
160 |
+
output_image = output_image.resize((width, height), Image.ANTIALIAS)
|
161 |
+
return output_image
|
162 |
+
|
163 |
+
def infer_from_file(
|
164 |
+
self,
|
165 |
+
original_image_file: str,
|
166 |
+
image_resize_factor: float = 1.0,
|
167 |
+
resize_output: bool = False,
|
168 |
+
):
|
169 |
+
original_image = Image.open(original_image_file)
|
170 |
+
return self.infer(original_image, image_resize_factor, resize_output)
|
enhance_me/mirnet/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .mirnet_model import build_mirnet_model
|
enhance_me/mirnet/models/dual_attention.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
|
5 |
+
def spatial_attention_block(input_tensor):
|
6 |
+
average_pooling = tf.reduce_max(input_tensor, axis=-1)
|
7 |
+
average_pooling = tf.expand_dims(average_pooling, axis=-1)
|
8 |
+
max_pooling = tf.reduce_mean(input_tensor, axis=-1)
|
9 |
+
max_pooling = tf.expand_dims(max_pooling, axis=-1)
|
10 |
+
concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
|
11 |
+
feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
|
12 |
+
feature_map = tf.nn.sigmoid(feature_map)
|
13 |
+
return input_tensor * feature_map
|
14 |
+
|
15 |
+
|
16 |
+
def channel_attention_block(input_tensor):
|
17 |
+
channels = list(input_tensor.shape)[-1]
|
18 |
+
average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
|
19 |
+
feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
|
20 |
+
feature_activations = layers.Conv2D(
|
21 |
+
filters=channels // 8, kernel_size=(1, 1), activation="relu"
|
22 |
+
)(feature_descriptor)
|
23 |
+
feature_activations = layers.Conv2D(
|
24 |
+
filters=channels, kernel_size=(1, 1), activation="sigmoid"
|
25 |
+
)(feature_activations)
|
26 |
+
return input_tensor * feature_activations
|
27 |
+
|
28 |
+
|
29 |
+
def dual_attention_unit_block(input_tensor):
|
30 |
+
channels = list(input_tensor.shape)[-1]
|
31 |
+
feature_map = layers.Conv2D(
|
32 |
+
channels, kernel_size=(3, 3), padding="same", activation="relu"
|
33 |
+
)(input_tensor)
|
34 |
+
feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
|
35 |
+
feature_map
|
36 |
+
)
|
37 |
+
channel_attention = channel_attention_block(feature_map)
|
38 |
+
spatial_attention = spatial_attention_block(feature_map)
|
39 |
+
concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
|
40 |
+
concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
|
41 |
+
return layers.Add()([input_tensor, concatenation])
|
enhance_me/mirnet/models/mirnet_model.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tensorflow.keras import layers, Input, Model
|
2 |
+
|
3 |
+
from .recursive_residual_blocks import recursive_residual_group
|
4 |
+
|
5 |
+
|
6 |
+
def build_mirnet_model(num_rrg, num_mrb, channels):
|
7 |
+
input_tensor = Input(shape=[None, None, 3])
|
8 |
+
x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
|
9 |
+
for _ in range(num_rrg):
|
10 |
+
x1 = recursive_residual_group(x1, num_mrb, channels)
|
11 |
+
conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
|
12 |
+
output_tensor = layers.Add()([input_tensor, conv])
|
13 |
+
return Model(input_tensor, output_tensor)
|
enhance_me/mirnet/models/recursive_residual_blocks.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
from .skff import selective_kernel_feature_fusion
|
5 |
+
from .dual_attention import dual_attention_unit_block
|
6 |
+
|
7 |
+
|
8 |
+
def down_sampling_module(input_tensor):
|
9 |
+
channels = list(input_tensor.shape)[-1]
|
10 |
+
main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
|
11 |
+
input_tensor
|
12 |
+
)
|
13 |
+
main_branch = layers.Conv2D(
|
14 |
+
channels, kernel_size=(3, 3), padding="same", activation="relu"
|
15 |
+
)(main_branch)
|
16 |
+
main_branch = layers.MaxPooling2D()(main_branch)
|
17 |
+
main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
|
18 |
+
skip_branch = layers.MaxPooling2D()(input_tensor)
|
19 |
+
skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
|
20 |
+
return layers.Add()([skip_branch, main_branch])
|
21 |
+
|
22 |
+
|
23 |
+
def up_sampling_module(input_tensor):
|
24 |
+
channels = list(input_tensor.shape)[-1]
|
25 |
+
main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
|
26 |
+
input_tensor
|
27 |
+
)
|
28 |
+
main_branch = layers.Conv2D(
|
29 |
+
channels, kernel_size=(3, 3), padding="same", activation="relu"
|
30 |
+
)(main_branch)
|
31 |
+
main_branch = layers.UpSampling2D()(main_branch)
|
32 |
+
main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
|
33 |
+
skip_branch = layers.UpSampling2D()(input_tensor)
|
34 |
+
skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
|
35 |
+
return layers.Add()([skip_branch, main_branch])
|
36 |
+
|
37 |
+
|
38 |
+
# MRB Block
|
39 |
+
def multi_scale_residual_block(input_tensor, channels):
|
40 |
+
# features
|
41 |
+
level1 = input_tensor
|
42 |
+
level2 = down_sampling_module(input_tensor)
|
43 |
+
level3 = down_sampling_module(level2)
|
44 |
+
# DAU
|
45 |
+
level1_dau = dual_attention_unit_block(level1)
|
46 |
+
level2_dau = dual_attention_unit_block(level2)
|
47 |
+
level3_dau = dual_attention_unit_block(level3)
|
48 |
+
# SKFF
|
49 |
+
level1_skff = selective_kernel_feature_fusion(
|
50 |
+
level1_dau,
|
51 |
+
up_sampling_module(level2_dau),
|
52 |
+
up_sampling_module(up_sampling_module(level3_dau)),
|
53 |
+
)
|
54 |
+
level2_skff = selective_kernel_feature_fusion(
|
55 |
+
down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
|
56 |
+
)
|
57 |
+
level3_skff = selective_kernel_feature_fusion(
|
58 |
+
down_sampling_module(down_sampling_module(level1_dau)),
|
59 |
+
down_sampling_module(level2_dau),
|
60 |
+
level3_dau,
|
61 |
+
)
|
62 |
+
# DAU 2
|
63 |
+
level1_dau_2 = dual_attention_unit_block(level1_skff)
|
64 |
+
level2_dau_2 = up_sampling_module((dual_attention_unit_block(level2_skff)))
|
65 |
+
level3_dau_2 = up_sampling_module(
|
66 |
+
up_sampling_module(dual_attention_unit_block(level3_skff))
|
67 |
+
)
|
68 |
+
# SKFF 2
|
69 |
+
skff_ = selective_kernel_feature_fusion(level1_dau_2, level3_dau_2, level3_dau_2)
|
70 |
+
conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
|
71 |
+
return layers.Add()([input_tensor, conv])
|
72 |
+
|
73 |
+
|
74 |
+
def recursive_residual_group(input_tensor, num_mrb, channels):
|
75 |
+
conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
|
76 |
+
for _ in range(num_mrb):
|
77 |
+
conv1 = multi_scale_residual_block(conv1, channels)
|
78 |
+
conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
|
79 |
+
return layers.Add()([conv2, input_tensor])
|
enhance_me/mirnet/models/skff.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
|
5 |
+
def selective_kernel_feature_fusion(
|
6 |
+
multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
|
7 |
+
):
|
8 |
+
channels = list(multi_scale_feature_1.shape)[-1]
|
9 |
+
combined_feature = layers.Add()(
|
10 |
+
[multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
|
11 |
+
)
|
12 |
+
gap = layers.GlobalAveragePooling2D()(combined_feature)
|
13 |
+
channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
|
14 |
+
compact_feature_representation = layers.Conv2D(
|
15 |
+
filters=channels // 8, kernel_size=(1, 1), activation="relu"
|
16 |
+
)(channel_wise_statistics)
|
17 |
+
feature_descriptor_1 = layers.Conv2D(
|
18 |
+
channels, kernel_size=(1, 1), activation="softmax"
|
19 |
+
)(compact_feature_representation)
|
20 |
+
feature_descriptor_2 = layers.Conv2D(
|
21 |
+
channels, kernel_size=(1, 1), activation="softmax"
|
22 |
+
)(compact_feature_representation)
|
23 |
+
feature_descriptor_3 = layers.Conv2D(
|
24 |
+
channels, kernel_size=(1, 1), activation="softmax"
|
25 |
+
)(compact_feature_representation)
|
26 |
+
feature_1 = multi_scale_feature_1 * feature_descriptor_1
|
27 |
+
feature_2 = multi_scale_feature_2 * feature_descriptor_2
|
28 |
+
feature_3 = multi_scale_feature_3 * feature_descriptor_3
|
29 |
+
aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
|
30 |
+
return aggregated_feature
|
notebooks/.gitkeep
ADDED
File without changes
|
notebooks/enhance_me_train.ipynb
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me/blob/mirnet/notebooks/enhance_me_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": null,
|
16 |
+
"metadata": {
|
17 |
+
"colab": {
|
18 |
+
"base_uri": "https://localhost:8080/"
|
19 |
+
},
|
20 |
+
"id": "1JryaVhtBHij",
|
21 |
+
"outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
|
22 |
+
},
|
23 |
+
"outputs": [],
|
24 |
+
"source": [
|
25 |
+
"!git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
|
26 |
+
"!pip install -qqq wandb streamlit"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"metadata": {
|
33 |
+
"id": "G_c4VtXWHR5l"
|
34 |
+
},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"import os\n",
|
38 |
+
"import sys\n",
|
39 |
+
"\n",
|
40 |
+
"sys.path.append(\"./enhance-me\")\n",
|
41 |
+
"\n",
|
42 |
+
"from PIL import Image\n",
|
43 |
+
"from enhance_me import commons\n",
|
44 |
+
"from enhance_me.mirnet import MIRNet"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": null,
|
50 |
+
"metadata": {
|
51 |
+
"id": "ZpBHbYaMIqP_"
|
52 |
+
},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"# @title MIRNet Train Configs\n",
|
56 |
+
"\n",
|
57 |
+
"experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
|
58 |
+
"image_size = 128 # @param {type:\"integer\"}\n",
|
59 |
+
"dataset_label = \"lol\" # @param [\"lol\"]\n",
|
60 |
+
"apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
|
61 |
+
"apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
|
62 |
+
"apply_random_rotation = True # @param {type:\"boolean\"}\n",
|
63 |
+
"wandb_api_key = \"\" # @param {type:\"string\"}\n",
|
64 |
+
"val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
|
65 |
+
"batch_size = 4 # @param {type:\"integer\"}\n",
|
66 |
+
"num_recursive_residual_groups = 3 # @param {type:\"slider\", min:1, max:5, step:1}\n",
|
67 |
+
"num_multi_scale_residual_blocks = 2 # @param {type:\"slider\", min:1, max:5, step:1}\n",
|
68 |
+
"learning_rate = 1e-4 # @param {type:\"number\"}\n",
|
69 |
+
"epsilon = 1e-3 # @param {type:\"number\"}\n",
|
70 |
+
"epochs = 50 # @param {type:\"slider\", min:10, max:100, step:5}"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": null,
|
76 |
+
"metadata": {
|
77 |
+
"colab": {
|
78 |
+
"base_uri": "https://localhost:8080/",
|
79 |
+
"height": 52
|
80 |
+
},
|
81 |
+
"id": "IVRoedqBIMuH",
|
82 |
+
"outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
|
83 |
+
},
|
84 |
+
"outputs": [],
|
85 |
+
"source": [
|
86 |
+
"mirnet = MIRNet(\n",
|
87 |
+
" experiment_name=experiment_name,\n",
|
88 |
+
" wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
|
89 |
+
")"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": null,
|
95 |
+
"metadata": {
|
96 |
+
"colab": {
|
97 |
+
"base_uri": "https://localhost:8080/"
|
98 |
+
},
|
99 |
+
"id": "O66Iwzx8IsGh",
|
100 |
+
"outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
|
101 |
+
},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"mirnet.build_datasets(\n",
|
105 |
+
" image_size=image_size,\n",
|
106 |
+
" dataset_label=dataset_label,\n",
|
107 |
+
" apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
|
108 |
+
" apply_random_vertical_flip=apply_random_vertical_flip,\n",
|
109 |
+
" apply_random_rotation=apply_random_rotation,\n",
|
110 |
+
" val_split=val_split,\n",
|
111 |
+
" batch_size=batch_size,\n",
|
112 |
+
")"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"metadata": {
|
119 |
+
"id": "tsfKrBCsL_Bb"
|
120 |
+
},
|
121 |
+
"outputs": [],
|
122 |
+
"source": [
|
123 |
+
"mirnet.build_model(\n",
|
124 |
+
" num_recursive_residual_groups=num_recursive_residual_groups,\n",
|
125 |
+
" num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
|
126 |
+
" learning_rate=learning_rate,\n",
|
127 |
+
" epsilon=epsilon,\n",
|
128 |
+
")"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": null,
|
134 |
+
"metadata": {
|
135 |
+
"colab": {
|
136 |
+
"base_uri": "https://localhost:8080/"
|
137 |
+
},
|
138 |
+
"id": "y3L9wlpkNziL",
|
139 |
+
"outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
|
140 |
+
},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"history = mirnet.train(epochs=epochs)"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": null,
|
149 |
+
"metadata": {},
|
150 |
+
"outputs": [],
|
151 |
+
"source": [
|
152 |
+
"mirnet.load_weights(os.path.join(mirnet.experiment_name, \"weights.h5\"))"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": null,
|
158 |
+
"metadata": {
|
159 |
+
"colab": {
|
160 |
+
"background_save": true
|
161 |
+
},
|
162 |
+
"id": "daFKbgBkiyzc"
|
163 |
+
},
|
164 |
+
"outputs": [],
|
165 |
+
"source": [
|
166 |
+
"for index, low_image_file in enumerate(mirnet.test_low_images):\n",
|
167 |
+
" original_image = Image.open(low_image_file)\n",
|
168 |
+
" enhanced_image = mirnet.infer(original_image)\n",
|
169 |
+
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
170 |
+
" commons.plot_results(\n",
|
171 |
+
" [original_image, ground_truth, ground_truth],\n",
|
172 |
+
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
173 |
+
" (18, 18),\n",
|
174 |
+
" )"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": null,
|
180 |
+
"metadata": {
|
181 |
+
"id": "dO-IbNQHkB3R"
|
182 |
+
},
|
183 |
+
"outputs": [],
|
184 |
+
"source": []
|
185 |
+
}
|
186 |
+
],
|
187 |
+
"metadata": {
|
188 |
+
"accelerator": "GPU",
|
189 |
+
"colab": {
|
190 |
+
"authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
|
191 |
+
"collapsed_sections": [],
|
192 |
+
"include_colab_link": true,
|
193 |
+
"machine_shape": "hm",
|
194 |
+
"name": "enhance-me-train.ipynb",
|
195 |
+
"provenance": []
|
196 |
+
},
|
197 |
+
"kernelspec": {
|
198 |
+
"display_name": "Python 3",
|
199 |
+
"name": "python3"
|
200 |
+
},
|
201 |
+
"language_info": {
|
202 |
+
"name": "python"
|
203 |
+
}
|
204 |
+
},
|
205 |
+
"nbformat": 4,
|
206 |
+
"nbformat_minor": 0
|
207 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
black
|
2 |
+
gdown
|
3 |
+
matplotlib
|
4 |
+
streamlit
|
5 |
+
tensorflow
|
6 |
+
tqdm
|
7 |
+
wandb
|