geekyrakshit commited on
Commit
2dd1081
·
1 Parent(s): 4b640eb

updated mirnet

Browse files
Files changed (1) hide show
  1. enhance_me/mirnet/mirnet.py +19 -28
enhance_me/mirnet/mirnet.py CHANGED
@@ -24,48 +24,39 @@ class MIRNet:
24
  def __init__(
25
  self,
26
  experiment_name: str,
27
- image_size: int = 256,
28
- dataset_label: str = "lol",
29
- build_datasets: bool = True,
30
- val_split: float = 0.2,
31
- batch_size: int = 16,
32
- apply_random_horizontal_flip: bool = True,
33
- apply_random_vertical_flip: bool = True,
34
- apply_random_rotation: bool = True,
35
  wandb_api_key=None,
36
  ) -> None:
37
  self.experiment_name = experiment_name
38
- if dataset_label == "lol":
39
- (low_images, enhanced_images), (
40
- self.test_low_images,
41
- self.test_enhanced_images,
42
- ) = download_lol_dataset()
43
- if build_datasets:
44
- self.data_loader = LowLightDataset(
45
- image_size=image_size,
46
- apply_random_horizontal_flip=apply_random_horizontal_flip,
47
- apply_random_vertical_flip=apply_random_vertical_flip,
48
- apply_random_rotation=apply_random_rotation,
49
- )
50
- self._build_datasets(
51
- low_images, enhanced_images, val_split=val_split, batch_size=batch_size
52
- )
53
  if wandb_api_key is not None:
54
  init_wandb("mirnet", experiment_name, wandb_api_key)
55
  self.using_wandb = True
56
  else:
57
  self.using_wandb = False
58
 
59
- def _build_datasets(
60
  self,
61
- low_light_images: List[str],
62
- enhanced_images: List[str],
 
 
 
63
  val_split: float = 0.2,
64
  batch_size: int = 16,
65
  ):
 
 
 
 
 
 
 
 
 
 
 
66
  (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
67
- low_light_images=low_light_images,
68
- enhanced_images=enhanced_images,
69
  val_split=val_split,
70
  batch_size=batch_size,
71
  )
 
24
  def __init__(
25
  self,
26
  experiment_name: str,
 
 
 
 
 
 
 
 
27
  wandb_api_key=None,
28
  ) -> None:
29
  self.experiment_name = experiment_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  if wandb_api_key is not None:
31
  init_wandb("mirnet", experiment_name, wandb_api_key)
32
  self.using_wandb = True
33
  else:
34
  self.using_wandb = False
35
 
36
+ def build_datasets(
37
  self,
38
+ image_size: int = 256,
39
+ dataset_label: str = "lol",
40
+ apply_random_horizontal_flip: bool = True,
41
+ apply_random_vertical_flip: bool = True,
42
+ apply_random_rotation: bool = True,
43
  val_split: float = 0.2,
44
  batch_size: int = 16,
45
  ):
46
+ if dataset_label == "lol":
47
+ (self.low_images, self.enhanced_images), (
48
+ self.test_low_images,
49
+ self.test_enhanced_images,
50
+ ) = download_lol_dataset()
51
+ self.data_loader = LowLightDataset(
52
+ image_size=image_size,
53
+ apply_random_horizontal_flip=apply_random_horizontal_flip,
54
+ apply_random_vertical_flip=apply_random_vertical_flip,
55
+ apply_random_rotation=apply_random_rotation,
56
+ )
57
  (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
58
+ low_light_images=self.low_images,
59
+ enhanced_images=self.enhanced_images,
60
  val_split=val_split,
61
  batch_size=batch_size,
62
  )