hbpkillerX commited on
Commit
62962ec
1 Parent(s): 50e9694

Upload 4 files

Browse files
Files changed (4) hide show
  1. eval.py +62 -0
  2. model.h5 +3 -0
  3. model.py +36 -0
  4. train.py +114 -0
eval.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ from keras.models import load_model
5
+ from PIL import Image
6
+ import tensorflow as tf
7
+ from model import get_model
8
+
9
+ IMAGE_SIZE = 128
10
+
11
+ def autocontrast(tensor, cutoff=0):
12
+ tensor = tf.cast(tensor, dtype=tf.float32)
13
+ min_val = tf.reduce_min(tensor)
14
+ max_val = tf.reduce_max(tensor)
15
+ range_val = max_val - min_val
16
+ adjusted_tensor = tf.clip_by_value(tf.cast(tf.round((tensor - min_val - cutoff) * (255 / (range_val - 2 * cutoff))), tf.uint8), 0, 255)
17
+ return adjusted_tensor
18
+
19
+ def read_image(image_path):
20
+ image = tf.io.read_file(image_path)
21
+ image = tf.image.decode_png(image, channels=3)
22
+ image = autocontrast(image)
23
+ image.set_shape([None, None, 3])
24
+ image = tf.cast(image, dtype=tf.float32) / 255
25
+ return image
26
+
27
+ def load_data(low_light_image_path):
28
+ low_light_image = read_image(low_light_image_path)
29
+ return low_light_image
30
+
31
+ def evaluate_model(images_path, eval_path, model_path):
32
+ model = get_model()
33
+ model.load_weights("./model.h5")
34
+
35
+ image_files = [f for f in os.listdir(images_path) if f.endswith(".jpg") or f.endswith(".png")]
36
+
37
+ for image_file in image_files:
38
+ image_path = os.path.join(images_path, image_file)
39
+ image = load_data(image_path)
40
+ image = np.expand_dims(image, axis=0)
41
+
42
+ generated_image = model.predict(image)
43
+
44
+ eval_file = f"eval_{image_file}"
45
+ eval_file_path = os.path.join(eval_path, eval_file)
46
+ generated_image = np.squeeze(generated_image, axis=0)
47
+ generated_image = np.clip(generated_image * 255, 0, 255).astype(np.uint8) # Clip values and convert to uint8
48
+ generated_image = Image.fromarray(generated_image)
49
+ generated_image.save(eval_file_path)
50
+ print(f"Generated image saved at: {eval_file_path}")
51
+
52
+ def main():
53
+ parser = argparse.ArgumentParser(description="Evaluate model on images")
54
+ parser.add_argument("images_path", type=str, help="Path to images directory")
55
+ parser.add_argument("eval_path", type=str, help="Path to evaluation directory")
56
+ parser.add_argument("--model_path", type=str, default="./model.h5", help="Path to model")
57
+ args = parser.parse_args()
58
+ os.makedirs(args.eval_path, exist_ok=True)
59
+ evaluate_model(args.images_path, args.eval_path, args.model_path)
60
+
61
+ if __name__ == "__main__":
62
+ main()
model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ab872a7258a78d9b5044e8fbb98ef0439446e7a9e8388e2dd5aaebc04eb4fed
3
+ size 6903640
model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+ def residual_block(inputs, filters):
4
+ x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(inputs)
5
+ x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same')(x)
6
+ x = tf.keras.layers.add([inputs, x])
7
+ x = tf.keras.layers.Activation('relu')(x)
8
+ return x
9
+
10
+ def get_model():
11
+ inputs = tf.keras.layers.Input(shape=(None, None, 3))
12
+ batch_size = tf.shape(inputs)[0]
13
+
14
+ conv1 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)
15
+ conv1 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(conv1)
16
+
17
+ conv2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(conv1)
18
+ conv2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(conv2)
19
+
20
+ conv3 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(conv2)
21
+ conv3 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(conv2)
22
+
23
+ res1 = residual_block(conv3, 128)
24
+ res2 = residual_block(res1, 128)
25
+ res3 = residual_block(res2, 128)
26
+ res4 = residual_block(res3, 128)
27
+ res5 = residual_block(res4, 128)
28
+
29
+ deconv1 = tf.keras.layers.Conv2DTranspose(64, (3, 3), padding='same', activation='relu')(res5)
30
+ deconv2 = tf.keras.layers.Conv2DTranspose(32, (3, 3), padding='same', activation='relu')(deconv1)
31
+
32
+ outputs = tf.keras.layers.Conv2D(3, (3, 3), padding='same', activation='sigmoid')(deconv2)
33
+ outputs=tf.keras.layers.add([inputs, outputs])
34
+
35
+ model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
36
+ return model
train.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from glob import glob
5
+ from PIL import Image, ImageOps
6
+ import matplotlib.pyplot as plt
7
+ import tensorflow as tf
8
+ from tensorflow import keras
9
+ from tensorflow.keras import layers
10
+ from model import get_model
11
+
12
+ # functions to create the dataset
13
+ random.seed(1)
14
+ IMAGE_SIZE = 128
15
+ BATCH_SIZE = 4
16
+ MAX_TRAIN_IMAGES = 300
17
+
18
+ def autocontrast(tensor, cutoff=0):
19
+ tensor = tf.cast(tensor, dtype=tf.float32)
20
+ min_val = tf.reduce_min(tensor)
21
+ max_val = tf.reduce_max(tensor)
22
+ range_val = max_val - min_val
23
+ adjusted_tensor = tf.clip_by_value(tf.cast(tf.round((tensor - min_val - cutoff) * (255 / (range_val - 2 * cutoff))), tf.uint8), 0, 255)
24
+ return adjusted_tensor
25
+
26
+ def read_image(image_path):
27
+ image = tf.io.read_file(image_path)
28
+ image = tf.image.decode_png(image, channels=3)
29
+ image = autocontrast(image)
30
+ image.set_shape([None, None, 3])
31
+ image = tf.cast(image, dtype=tf.float32) / 255
32
+ return image
33
+
34
+
35
+ def random_crop(low_image, enhanced_image):
36
+ low_image_shape = tf.shape(low_image)[:2]
37
+ low_w = tf.random.uniform(
38
+ shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
39
+ )
40
+ low_h = tf.random.uniform(
41
+ shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
42
+ )
43
+ enhanced_w = low_w
44
+ enhanced_h = low_h
45
+ low_image_cropped = low_image[
46
+ low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
47
+ ]
48
+ enhanced_image_cropped = enhanced_image[
49
+ enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
50
+ ]
51
+ return low_image_cropped, enhanced_image_cropped
52
+
53
+
54
+ def load_data(low_light_image_path, enhanced_image_path):
55
+ low_light_image = read_image(low_light_image_path)
56
+ enhanced_image = read_image(enhanced_image_path)
57
+ low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
58
+ return low_light_image, enhanced_image
59
+
60
+
61
+ def get_dataset(low_light_images, enhanced_images):
62
+ dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
63
+ dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
64
+ dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
65
+ return dataset
66
+
67
+ # Loss functions
68
+
69
+ class CustomLoss:
70
+ def __init__(self, perceptual_loss_model):
71
+ self.perceptual_loss_model = perceptual_loss_model
72
+ def perceptual_loss(self, y_true, y_pred):
73
+ y_true_features = self.perceptual_loss_model(y_true)
74
+ y_pred_features = self.perceptual_loss_model(y_pred)
75
+ loss = tf.reduce_mean(tf.square(y_true_features[0] - y_pred_features[0])) + tf.reduce_mean(tf.square(y_true_features[1] - y_pred_features[1]))
76
+ return loss
77
+ def charbonnier_loss(self, y_true, y_pred):
78
+ return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
79
+ def __call__(self, y_true, y_pred):
80
+ return 0.5*self.perceptual_loss(y_true, y_pred) + 0.4*self.charbonnier_loss(y_true, y_pred)
81
+
82
+ def peak_signal_noise_ratio(y_true, y_pred):
83
+ return tf.image.psnr(y_pred, y_true, max_val=255.0)
84
+
85
+ def main():
86
+ train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
87
+ train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
88
+
89
+ val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
90
+ val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
91
+
92
+ train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
93
+ val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
94
+
95
+ #Model for calculating perceptual loss
96
+ vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
97
+ for layer in vgg.layers:
98
+ layer.trainable = False #Freeze all the layers, since this model is for evaluation only
99
+ outputs = [vgg.get_layer('block3_conv3').output, vgg.get_layer('block4_conv3').output]
100
+ perceptual_loss_model = tf.keras.models.Model(inputs=vgg.input, outputs=outputs)
101
+
102
+ optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
103
+ loss = CustomLoss(perceptual_loss_model)
104
+ model = get_model()
105
+
106
+ model.compile(
107
+ optimizer=optimizer, loss=loss, metrics=[peak_signal_noise_ratio]
108
+ )
109
+
110
+ history = model.fit(train_dataset, validation_data=val_dataset, epochs=50)
111
+ model.save_weights("model.h5")
112
+
113
+ if __name__ == "__main__":
114
+ main()