Spaces:
Runtime error
Runtime error
Commit
·
192c48a
1
Parent(s):
c8d52e7
added download function for lol dataset
Browse files- .gitignore +3 -0
- enhance_me/commons.py +22 -2
- enhance_me/mirnet/mirnet.py +9 -1
.gitignore
CHANGED
|
@@ -127,3 +127,6 @@ dmypy.json
|
|
| 127 |
|
| 128 |
# Pyre type checker
|
| 129 |
.pyre/
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
# Pyre type checker
|
| 129 |
.pyre/
|
| 130 |
+
|
| 131 |
+
# Datasets
|
| 132 |
+
datasets/
|
enhance_me/commons.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
import os
|
| 2 |
import wandb
|
| 3 |
-
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def read_image(image_path):
|
| 8 |
image = tf.io.read_file(image_path)
|
|
@@ -39,5 +42,22 @@ def closest_number(n, m):
|
|
| 39 |
|
| 40 |
def init_wandb(project_name, experiment_name, wandb_api_key):
|
| 41 |
if project_name is not None and experiment_name is not None:
|
| 42 |
-
os.environ[
|
| 43 |
wandb.init(project=project_name, name=experiment_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import wandb
|
| 3 |
+
from glob import glob
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow.keras import utils
|
| 8 |
+
|
| 9 |
|
| 10 |
def read_image(image_path):
|
| 11 |
image = tf.io.read_file(image_path)
|
|
|
|
| 42 |
|
| 43 |
def init_wandb(project_name, experiment_name, wandb_api_key):
|
| 44 |
if project_name is not None and experiment_name is not None:
|
| 45 |
+
os.environ["WANDB_API_KEY"] = wandb_api_key
|
| 46 |
wandb.init(project=project_name, name=experiment_name)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def download_lol_dataset():
|
| 50 |
+
utils.get_file(
|
| 51 |
+
"lol_dataset.zip",
|
| 52 |
+
"https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip",
|
| 53 |
+
cache_dir="./",
|
| 54 |
+
cache_subdir="./datasets",
|
| 55 |
+
extract=True,
|
| 56 |
+
)
|
| 57 |
+
low_images = sorted(glob("./datasets/lol_dataset/our485/low/*"))
|
| 58 |
+
enhanced_images = sorted(glob("./datasets/lol_dataset/our485/high/*"))
|
| 59 |
+
assert len(low_images) == len(enhanced_images)
|
| 60 |
+
test_low_images = sorted(glob("./datasets/lol_dataset/eval15/low/*"))
|
| 61 |
+
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
| 62 |
+
assert len(test_low_images) == len(test_enhanced_images)
|
| 63 |
+
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
enhance_me/mirnet/mirnet.py
CHANGED
|
@@ -12,7 +12,12 @@ from wandb.keras import WandbCallback
|
|
| 12 |
from .dataloader import LowLightDataset
|
| 13 |
from .models import build_mirnet_model
|
| 14 |
from .losses import CharbonnierLoss
|
| 15 |
-
from ..commons import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class MIRNet:
|
|
@@ -20,12 +25,15 @@ class MIRNet:
|
|
| 20 |
self,
|
| 21 |
experiment_name: str,
|
| 22 |
image_size: int = 256,
|
|
|
|
| 23 |
apply_random_horizontal_flip: bool = True,
|
| 24 |
apply_random_vertical_flip: bool = True,
|
| 25 |
apply_random_rotation: bool = True,
|
| 26 |
wandb_api_key=None,
|
| 27 |
) -> None:
|
| 28 |
self.experiment_name = experiment_name
|
|
|
|
|
|
|
| 29 |
self.data_loader = LowLightDataset(
|
| 30 |
image_size=image_size,
|
| 31 |
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
|
|
|
| 12 |
from .dataloader import LowLightDataset
|
| 13 |
from .models import build_mirnet_model
|
| 14 |
from .losses import CharbonnierLoss
|
| 15 |
+
from ..commons import (
|
| 16 |
+
peak_signal_noise_ratio,
|
| 17 |
+
closest_number,
|
| 18 |
+
init_wandb,
|
| 19 |
+
download_lol_dataset,
|
| 20 |
+
)
|
| 21 |
|
| 22 |
|
| 23 |
class MIRNet:
|
|
|
|
| 25 |
self,
|
| 26 |
experiment_name: str,
|
| 27 |
image_size: int = 256,
|
| 28 |
+
dataset_label: str = "lol",
|
| 29 |
apply_random_horizontal_flip: bool = True,
|
| 30 |
apply_random_vertical_flip: bool = True,
|
| 31 |
apply_random_rotation: bool = True,
|
| 32 |
wandb_api_key=None,
|
| 33 |
) -> None:
|
| 34 |
self.experiment_name = experiment_name
|
| 35 |
+
if dataset_label == "lol":
|
| 36 |
+
download_lol_dataset()
|
| 37 |
self.data_loader = LowLightDataset(
|
| 38 |
image_size=image_size,
|
| 39 |
apply_random_horizontal_flip=apply_random_horizontal_flip,
|