Spaces:
Runtime error
Runtime error
Commit
·
2dd1081
1
Parent(s):
4b640eb
updated mirnet
Browse files- enhance_me/mirnet/mirnet.py +19 -28
enhance_me/mirnet/mirnet.py
CHANGED
@@ -24,48 +24,39 @@ class MIRNet:
|
|
24 |
def __init__(
|
25 |
self,
|
26 |
experiment_name: str,
|
27 |
-
image_size: int = 256,
|
28 |
-
dataset_label: str = "lol",
|
29 |
-
build_datasets: bool = True,
|
30 |
-
val_split: float = 0.2,
|
31 |
-
batch_size: int = 16,
|
32 |
-
apply_random_horizontal_flip: bool = True,
|
33 |
-
apply_random_vertical_flip: bool = True,
|
34 |
-
apply_random_rotation: bool = True,
|
35 |
wandb_api_key=None,
|
36 |
) -> None:
|
37 |
self.experiment_name = experiment_name
|
38 |
-
if dataset_label == "lol":
|
39 |
-
(low_images, enhanced_images), (
|
40 |
-
self.test_low_images,
|
41 |
-
self.test_enhanced_images,
|
42 |
-
) = download_lol_dataset()
|
43 |
-
if build_datasets:
|
44 |
-
self.data_loader = LowLightDataset(
|
45 |
-
image_size=image_size,
|
46 |
-
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
47 |
-
apply_random_vertical_flip=apply_random_vertical_flip,
|
48 |
-
apply_random_rotation=apply_random_rotation,
|
49 |
-
)
|
50 |
-
self._build_datasets(
|
51 |
-
low_images, enhanced_images, val_split=val_split, batch_size=batch_size
|
52 |
-
)
|
53 |
if wandb_api_key is not None:
|
54 |
init_wandb("mirnet", experiment_name, wandb_api_key)
|
55 |
self.using_wandb = True
|
56 |
else:
|
57 |
self.using_wandb = False
|
58 |
|
59 |
-
def
|
60 |
self,
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
63 |
val_split: float = 0.2,
|
64 |
batch_size: int = 16,
|
65 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
(self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
|
67 |
-
low_light_images=
|
68 |
-
enhanced_images=enhanced_images,
|
69 |
val_split=val_split,
|
70 |
batch_size=batch_size,
|
71 |
)
|
|
|
24 |
def __init__(
|
25 |
self,
|
26 |
experiment_name: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
wandb_api_key=None,
|
28 |
) -> None:
|
29 |
self.experiment_name = experiment_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
if wandb_api_key is not None:
|
31 |
init_wandb("mirnet", experiment_name, wandb_api_key)
|
32 |
self.using_wandb = True
|
33 |
else:
|
34 |
self.using_wandb = False
|
35 |
|
36 |
+
def build_datasets(
|
37 |
self,
|
38 |
+
image_size: int = 256,
|
39 |
+
dataset_label: str = "lol",
|
40 |
+
apply_random_horizontal_flip: bool = True,
|
41 |
+
apply_random_vertical_flip: bool = True,
|
42 |
+
apply_random_rotation: bool = True,
|
43 |
val_split: float = 0.2,
|
44 |
batch_size: int = 16,
|
45 |
):
|
46 |
+
if dataset_label == "lol":
|
47 |
+
(self.low_images, self.enhanced_images), (
|
48 |
+
self.test_low_images,
|
49 |
+
self.test_enhanced_images,
|
50 |
+
) = download_lol_dataset()
|
51 |
+
self.data_loader = LowLightDataset(
|
52 |
+
image_size=image_size,
|
53 |
+
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
54 |
+
apply_random_vertical_flip=apply_random_vertical_flip,
|
55 |
+
apply_random_rotation=apply_random_rotation,
|
56 |
+
)
|
57 |
(self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
|
58 |
+
low_light_images=self.low_images,
|
59 |
+
enhanced_images=self.enhanced_images,
|
60 |
val_split=val_split,
|
61 |
batch_size=batch_size,
|
62 |
)
|