hbpkillerX
commited on
Commit
•
23811f4
1
Parent(s):
e4ea306
Delete train.py
Browse files
train.py
DELETED
@@ -1,114 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import random
|
3 |
-
import numpy as np
|
4 |
-
from glob import glob
|
5 |
-
from PIL import Image, ImageOps
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
import tensorflow as tf
|
8 |
-
from tensorflow import keras
|
9 |
-
from tensorflow.keras import layers
|
10 |
-
from model import get_model
|
11 |
-
|
12 |
-
# functions to create the dataset
|
13 |
-
random.seed(1)
|
14 |
-
IMAGE_SIZE = 128
|
15 |
-
BATCH_SIZE = 4
|
16 |
-
MAX_TRAIN_IMAGES = 300
|
17 |
-
|
18 |
-
def autocontrast(tensor, cutoff=0):
|
19 |
-
tensor = tf.cast(tensor, dtype=tf.float32)
|
20 |
-
min_val = tf.reduce_min(tensor)
|
21 |
-
max_val = tf.reduce_max(tensor)
|
22 |
-
range_val = max_val - min_val
|
23 |
-
adjusted_tensor = tf.clip_by_value(tf.cast(tf.round((tensor - min_val - cutoff) * (255 / (range_val - 2 * cutoff))), tf.uint8), 0, 255)
|
24 |
-
return adjusted_tensor
|
25 |
-
|
26 |
-
def read_image(image_path):
|
27 |
-
image = tf.io.read_file(image_path)
|
28 |
-
image = tf.image.decode_png(image, channels=3)
|
29 |
-
image = autocontrast(image)
|
30 |
-
image.set_shape([None, None, 3])
|
31 |
-
image = tf.cast(image, dtype=tf.float32) / 255
|
32 |
-
return image
|
33 |
-
|
34 |
-
|
35 |
-
def random_crop(low_image, enhanced_image):
|
36 |
-
low_image_shape = tf.shape(low_image)[:2]
|
37 |
-
low_w = tf.random.uniform(
|
38 |
-
shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
|
39 |
-
)
|
40 |
-
low_h = tf.random.uniform(
|
41 |
-
shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
|
42 |
-
)
|
43 |
-
enhanced_w = low_w
|
44 |
-
enhanced_h = low_h
|
45 |
-
low_image_cropped = low_image[
|
46 |
-
low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
|
47 |
-
]
|
48 |
-
enhanced_image_cropped = enhanced_image[
|
49 |
-
enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
|
50 |
-
]
|
51 |
-
return low_image_cropped, enhanced_image_cropped
|
52 |
-
|
53 |
-
|
54 |
-
def load_data(low_light_image_path, enhanced_image_path):
|
55 |
-
low_light_image = read_image(low_light_image_path)
|
56 |
-
enhanced_image = read_image(enhanced_image_path)
|
57 |
-
low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
|
58 |
-
return low_light_image, enhanced_image
|
59 |
-
|
60 |
-
|
61 |
-
def get_dataset(low_light_images, enhanced_images):
|
62 |
-
dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
|
63 |
-
dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
|
64 |
-
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
|
65 |
-
return dataset
|
66 |
-
|
67 |
-
# Loss functions
|
68 |
-
|
69 |
-
class CustomLoss:
|
70 |
-
def __init__(self, perceptual_loss_model):
|
71 |
-
self.perceptual_loss_model = perceptual_loss_model
|
72 |
-
def perceptual_loss(self, y_true, y_pred):
|
73 |
-
y_true_features = self.perceptual_loss_model(y_true)
|
74 |
-
y_pred_features = self.perceptual_loss_model(y_pred)
|
75 |
-
loss = tf.reduce_mean(tf.square(y_true_features[0] - y_pred_features[0])) + tf.reduce_mean(tf.square(y_true_features[1] - y_pred_features[1]))
|
76 |
-
return loss
|
77 |
-
def charbonnier_loss(self, y_true, y_pred):
|
78 |
-
return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
|
79 |
-
def __call__(self, y_true, y_pred):
|
80 |
-
return 0.5*self.perceptual_loss(y_true, y_pred) + 0.4*self.charbonnier_loss(y_true, y_pred)
|
81 |
-
|
82 |
-
def peak_signal_noise_ratio(y_true, y_pred):
|
83 |
-
return tf.image.psnr(y_pred, y_true, max_val=255.0)
|
84 |
-
|
85 |
-
def main():
|
86 |
-
train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
|
87 |
-
train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
|
88 |
-
|
89 |
-
val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
|
90 |
-
val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
|
91 |
-
|
92 |
-
train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
|
93 |
-
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
|
94 |
-
|
95 |
-
#Model for calculating perceptual loss
|
96 |
-
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
|
97 |
-
for layer in vgg.layers:
|
98 |
-
layer.trainable = False #Freeze all the layers, since this model is for evaluation only
|
99 |
-
outputs = [vgg.get_layer('block3_conv3').output, vgg.get_layer('block4_conv3').output]
|
100 |
-
perceptual_loss_model = tf.keras.models.Model(inputs=vgg.input, outputs=outputs)
|
101 |
-
|
102 |
-
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
|
103 |
-
loss = CustomLoss(perceptual_loss_model)
|
104 |
-
model = get_model()
|
105 |
-
|
106 |
-
model.compile(
|
107 |
-
optimizer=optimizer, loss=loss, metrics=[peak_signal_noise_ratio]
|
108 |
-
)
|
109 |
-
|
110 |
-
history = model.fit(train_dataset, validation_data=val_dataset, epochs=50)
|
111 |
-
model.save_weights("model.h5")
|
112 |
-
|
113 |
-
if __name__ == "__main__":
|
114 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|