geekyrakshit commited on
Commit
7e6ea19
·
1 Parent(s): c23af12

updated mirnet

Browse files
Files changed (1) hide show
  1. enhance_me/mirnet/mirnet.py +11 -9
enhance_me/mirnet/mirnet.py CHANGED
@@ -26,6 +26,7 @@ class MIRNet:
26
  experiment_name: str,
27
  image_size: int = 256,
28
  dataset_label: str = "lol",
 
29
  val_split: float = 0.2,
30
  batch_size: int = 16,
31
  apply_random_horizontal_flip: bool = True,
@@ -39,15 +40,16 @@ class MIRNet:
39
  self.test_low_images,
40
  self.test_enhanced_images,
41
  ) = download_lol_dataset()
42
- self.data_loader = LowLightDataset(
43
- image_size=image_size,
44
- apply_random_horizontal_flip=apply_random_horizontal_flip,
45
- apply_random_vertical_flip=apply_random_vertical_flip,
46
- apply_random_rotation=apply_random_rotation,
47
- )
48
- self._build_datasets(
49
- low_images, enhanced_images, val_split=val_split, batch_size=batch_size
50
- )
 
51
  if wandb_api_key is not None:
52
  init_wandb("mirnet", experiment_name, wandb_api_key)
53
  self.using_wandb = True
 
26
  experiment_name: str,
27
  image_size: int = 256,
28
  dataset_label: str = "lol",
29
+ build_datasets: bool = True,
30
  val_split: float = 0.2,
31
  batch_size: int = 16,
32
  apply_random_horizontal_flip: bool = True,
 
40
  self.test_low_images,
41
  self.test_enhanced_images,
42
  ) = download_lol_dataset()
43
+ if build_datasets:
44
+ self.data_loader = LowLightDataset(
45
+ image_size=image_size,
46
+ apply_random_horizontal_flip=apply_random_horizontal_flip,
47
+ apply_random_vertical_flip=apply_random_vertical_flip,
48
+ apply_random_rotation=apply_random_rotation,
49
+ )
50
+ self._build_datasets(
51
+ low_images, enhanced_images, val_split=val_split, batch_size=batch_size
52
+ )
53
  if wandb_api_key is not None:
54
  init_wandb("mirnet", experiment_name, wandb_api_key)
55
  self.using_wandb = True