Anuj-Panthri commited on
Commit
34eb6c0
·
1 Parent(s): 6216ecd

added train,visualize_results scripts

Browse files
configs/experiment1.yaml CHANGED
@@ -6,9 +6,11 @@ model: model_v1
6
  # common parameters
7
  seed: 324
8
  train_size: 0.8
9
- image_size: 224
 
10
  shuffle: False
11
 
12
  # training related
13
  batch_size: 16
14
- epochs: 10
 
 
6
  # common parameters
7
  seed: 324
8
  train_size: 0.8
9
+ # image_size: 224
10
+ image_size: 64
11
  shuffle: False
12
 
13
  # training related
14
  batch_size: 16
15
+ # epochs: 10
16
+ epochs: 02
src/scripts/{create_sub_task.py → create_task.py} RENAMED
@@ -5,9 +5,9 @@ def create_file(file_path,file_content):
5
  with open(file_path,"w") as f:
6
  f.write(file_content)
7
 
8
- def create_data(data_dir,dataset_name,sub_task_dir):
9
- # call src/sub_task/scripts/create_dataset.py dataset_name
10
- os.system(f"python {sub_task_dir}/scripts/create_dataset.py {dataset_name}")
11
 
12
  register_datasets_file_path = os.path.join(data_dir,"register_datasets.py")
13
  create_file(register_datasets_file_path,
@@ -19,7 +19,7 @@ datasets = ["{dataset_name}"]
19
 
20
 
21
 
22
- def create_model(model_dir:str, model_name:str, sub_task_dir:str):
23
  base_model_interface_path = os.path.join(model_dir,"base_model_interface.py")
24
 
25
  create_file(base_model_interface_path,
@@ -27,26 +27,34 @@ def create_model(model_dir:str, model_name:str, sub_task_dir:str):
27
  from abc import ABC, abstractmethod
28
 
29
  # BaseModel Abstract class
30
- # all the models within this sub_task must inherit this class
31
 
32
  class BaseModel(ABC):
33
  @abstractmethod
34
  def train(self):
35
  pass
36
 
 
 
 
 
37
  @abstractmethod
38
  def predict(self,inputs):
39
  pass
 
 
 
 
40
  """)
41
 
42
 
43
- # call src/sub_task/scripts/create_model.py model_name
44
- os.system(f"python {sub_task_dir}/scripts/create_model.py {model_name}")
45
 
46
 
47
  register_models_path = os.path.join(model_dir,"register_models.py")
48
  create_file(register_models_path,
49
- f"""# register models of this sub_task here
50
  models = ["{model_name}"]
51
  """)
52
 
@@ -73,7 +81,7 @@ models = ["{model_name}"]
73
  """)
74
 
75
 
76
- def create_scripts(scripts_dir,sub_task):
77
  create_dataset_path = os.path.join(scripts_dir,"create_dataset.py")
78
  create_file(create_dataset_path,
79
  f"""import os,shutil
@@ -86,7 +94,7 @@ def create_file(file_path,file_content):
86
  def create_dataset(args):
87
  dataset_name = args.name
88
  force_flag = args.force
89
- datasets_dir = os.path.join('src','{sub_task}','data','datasets')
90
 
91
  os.makedirs(datasets_dir,exist_ok=True)
92
  dataset_path = os.path.join(datasets_dir,dataset_name+".py")
@@ -157,7 +165,7 @@ def create_file(file_path,file_content):
157
  def create_model(args):
158
  model_name = args.name
159
  force_flag = args.force
160
- models_dir = os.path.join('src','{sub_task}','model',"models")
161
  os.makedirs(models_dir,exist_ok=True)
162
  model_path = os.path.join(models_dir,model_name+".py")
163
 
@@ -173,7 +181,7 @@ def create_model(args):
173
 
174
  model_name_camel_case = "".join([part.capitalize() for part in model_name.split("_")])
175
  create_file(model_path,
176
- f\"\"\"from src.{sub_task}.model.base_model_interface import BaseModel
177
 
178
  class Model(BaseModel):
179
  def train(self):
@@ -197,35 +205,35 @@ if __name__=="__main__":
197
 
198
 
199
 
200
- def create_sub_task(args):
201
- """Used to create sub_task within our main task"""
202
- sub_task = args.sub_task
203
  force_flag = args.force
204
  dataset_name = "dataset1"
205
  model_name = "model1"
206
 
207
- sub_task_dir = os.path.join('src',sub_task)
208
- data_dir = os.path.join(sub_task_dir,'data')
209
- model_dir = os.path.join(sub_task_dir,'model')
210
- scripts_dir = os.path.join(sub_task_dir,"scripts")
211
  # print(scripts_dir)
212
- # deleted old sub_task if force flag exists and sub_task already exists
213
- if os.path.exists(sub_task_dir):
214
  if force_flag:
215
- print("Replacing existing sub_task:",sub_task)
216
- shutil.rmtree(sub_task_dir)
217
  else:
218
- print(f"{sub_task} already exists, use --force flag if you want to reset it to default")
219
  exit()
220
 
221
  # create empty folders
222
- os.makedirs(sub_task_dir,exist_ok=True)
223
  os.makedirs(data_dir,exist_ok=True)
224
  os.makedirs(model_dir,exist_ok=True)
225
  os.makedirs(scripts_dir,exist_ok=True)
226
 
227
  # make config validator file
228
- validate_config_file_path = os.path.join(sub_task_dir,"validate_config.py")
229
  create_file(validate_config_file_path,
230
  '''# from cerberus import Validator
231
 
@@ -252,22 +260,22 @@ schema = {
252
  ''')
253
 
254
  # make scripts files
255
- create_scripts(scripts_dir,sub_task)
256
 
257
  # make data files
258
- create_data(data_dir,dataset_name,sub_task_dir)
259
 
260
  # make model files
261
- create_model(model_dir,model_name,sub_task_dir)
262
 
263
 
264
  def main():
265
- parser = argparse.ArgumentParser(description="Create blueprint sub_task")
266
- parser.add_argument('sub_task',type=str,help="sub_task of project (e.g., simple_regression_colorization)")
267
- parser.add_argument("--force",action="store_true",help="forcefully replace old existing sub_task to default",default=False)
268
  args = parser.parse_args()
269
 
270
- create_sub_task(args)
271
 
272
  if __name__=="__main__":
273
  main()
 
5
  with open(file_path,"w") as f:
6
  f.write(file_content)
7
 
8
+ def create_data(data_dir,dataset_name,task_dir):
9
+ # call src/task/scripts/create_dataset.py dataset_name
10
+ os.system(f"python {task_dir}/scripts/create_dataset.py {dataset_name}")
11
 
12
  register_datasets_file_path = os.path.join(data_dir,"register_datasets.py")
13
  create_file(register_datasets_file_path,
 
19
 
20
 
21
 
22
+ def create_model(model_dir:str, model_name:str, task_dir:str):
23
  base_model_interface_path = os.path.join(model_dir,"base_model_interface.py")
24
 
25
  create_file(base_model_interface_path,
 
27
  from abc import ABC, abstractmethod
28
 
29
  # BaseModel Abstract class
30
+ # all the models within this task must inherit this class
31
 
32
  class BaseModel(ABC):
33
  @abstractmethod
34
  def train(self):
35
  pass
36
 
37
+ @abstractmethod
38
+ def evaluate(self):
39
+ pass
40
+
41
  @abstractmethod
42
  def predict(self,inputs):
43
  pass
44
+
45
+ @abstractmethod
46
+ def show_results(self):
47
+ pass
48
  """)
49
 
50
 
51
+ # call src/task/scripts/create_model.py model_name
52
+ os.system(f"python {task_dir}/scripts/create_model.py {model_name}")
53
 
54
 
55
  register_models_path = os.path.join(model_dir,"register_models.py")
56
  create_file(register_models_path,
57
+ f"""# register models of this task here
58
  models = ["{model_name}"]
59
  """)
60
 
 
81
  """)
82
 
83
 
84
+ def create_scripts(scripts_dir,task):
85
  create_dataset_path = os.path.join(scripts_dir,"create_dataset.py")
86
  create_file(create_dataset_path,
87
  f"""import os,shutil
 
94
  def create_dataset(args):
95
  dataset_name = args.name
96
  force_flag = args.force
97
+ datasets_dir = os.path.join('src','{task}','data','datasets')
98
 
99
  os.makedirs(datasets_dir,exist_ok=True)
100
  dataset_path = os.path.join(datasets_dir,dataset_name+".py")
 
165
  def create_model(args):
166
  model_name = args.name
167
  force_flag = args.force
168
+ models_dir = os.path.join('src','{task}','model',"models")
169
  os.makedirs(models_dir,exist_ok=True)
170
  model_path = os.path.join(models_dir,model_name+".py")
171
 
 
181
 
182
  model_name_camel_case = "".join([part.capitalize() for part in model_name.split("_")])
183
  create_file(model_path,
184
+ f\"\"\"from src.{task}.model.base_model_interface import BaseModel
185
 
186
  class Model(BaseModel):
187
  def train(self):
 
205
 
206
 
207
 
208
+ def create_task(args):
209
+ """Used to create task within our main task"""
210
+ task = args.task
211
  force_flag = args.force
212
  dataset_name = "dataset1"
213
  model_name = "model1"
214
 
215
+ task_dir = os.path.join('src',task)
216
+ data_dir = os.path.join(task_dir,'data')
217
+ model_dir = os.path.join(task_dir,'model')
218
+ scripts_dir = os.path.join(task_dir,"scripts")
219
  # print(scripts_dir)
220
+ # deleted old task if force flag exists and task already exists
221
+ if os.path.exists(task_dir):
222
  if force_flag:
223
+ print("Replacing existing task:",task)
224
+ shutil.rmtree(task_dir)
225
  else:
226
+ print(f"{task} already exists, use --force flag if you want to reset it to default")
227
  exit()
228
 
229
  # create empty folders
230
+ os.makedirs(task_dir,exist_ok=True)
231
  os.makedirs(data_dir,exist_ok=True)
232
  os.makedirs(model_dir,exist_ok=True)
233
  os.makedirs(scripts_dir,exist_ok=True)
234
 
235
  # make config validator file
236
+ validate_config_file_path = os.path.join(task_dir,"validate_config.py")
237
  create_file(validate_config_file_path,
238
  '''# from cerberus import Validator
239
 
 
260
  ''')
261
 
262
  # make scripts files
263
+ create_scripts(scripts_dir,task)
264
 
265
  # make data files
266
+ create_data(data_dir,dataset_name,task_dir)
267
 
268
  # make model files
269
+ create_model(model_dir,model_name,task_dir)
270
 
271
 
272
  def main():
273
+ parser = argparse.ArgumentParser(description="Create blueprint task")
274
+ parser.add_argument('task',type=str,help="task of project (e.g., simple_regression_colorization)")
275
+ parser.add_argument("--force",action="store_true",help="forcefully replace old existing task to default",default=False)
276
  args = parser.parse_args()
277
 
278
+ create_task(args)
279
 
280
  if __name__=="__main__":
281
  main()
src/scripts/train.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from src.utils.config_loader import Config
4
+ from src.utils import config_loader
5
+ from src.utils.script_utils import validate_config
6
+ import importlib
7
+ from pathlib import Path
8
+
9
+
10
+ def train(args):
11
+ config_file_path = args.config_file
12
+ config = Config(config_file_path)
13
+
14
+ # validate config
15
+ validate_config(config)
16
+
17
+ # set config globally
18
+ config_loader.config = config
19
+
20
+ # now visualize the dataset
21
+ Model = importlib.import_module(f"src.{config.task}.model.models.{config.model}").Model
22
+
23
+
24
+ model_dir = os.path.join("models",config.task,config.model)
25
+ os.makedirs(model_dir,exist_ok=True)
26
+ model_save_path = os.path.join(model_dir,"model.weights.h5")
27
+
28
+ model = Model()
29
+ model.train()
30
+ model.save(model_save_path)
31
+ metrics = model.evaluate()
32
+ print("Model Evaluation Metrics:",metrics)
33
+
34
+ def main():
35
+ parser = argparse.ArgumentParser(description="train model based on config yaml file")
36
+ parser.add_argument("config_file",type=str)
37
+ args = parser.parse_args()
38
+ train(args)
39
+
40
+ if __name__=="__main__":
41
+ main()
src/scripts/visualize_results.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from src.utils.config_loader import Config
4
+ from src.utils import config_loader
5
+ from src.utils.script_utils import validate_config
6
+ import importlib
7
+
8
+
9
+ def visualize_dataset(args):
10
+ config_file_path = args.config_file
11
+ config = Config(config_file_path)
12
+
13
+ # validate config
14
+ validate_config(config)
15
+
16
+ # set config globally
17
+ config_loader.config = config
18
+
19
+ # now load model and visualize the results
20
+ model_dir = os.path.join("models",config.task,config.model)
21
+ model_save_path = os.path.join(model_dir,"model.weights.h5")
22
+
23
+ if not os.path.exists(model_save_path):
24
+ raise Exception("No model found:","first use train.py to train and export a model")
25
+
26
+ Model = importlib.import_module(f"src.{config.task}.model.models.{config.model}").Model
27
+ model = Model(model_save_path)
28
+
29
+ # model.train_ds
30
+ model.show_results()
31
+
32
+
33
+
34
+ def main():
35
+ parser = argparse.ArgumentParser(description="Prepare dataset based on config yaml file")
36
+ parser.add_argument("config_file",type=str)
37
+ args = parser.parse_args()
38
+ visualize_dataset(args)
39
+
40
+ if __name__=="__main__":
41
+ main()
src/simple_regression_colorization/model/base_model_interface.py CHANGED
@@ -9,6 +9,14 @@ class BaseModel(ABC):
9
  def train(self):
10
  pass
11
 
 
 
 
 
12
  @abstractmethod
13
  def predict(self,inputs):
14
  pass
 
 
 
 
 
9
  def train(self):
10
  pass
11
 
12
+ @abstractmethod
13
+ def evaluate(self):
14
+ pass
15
+
16
  @abstractmethod
17
  def predict(self,inputs):
18
  pass
19
+
20
+ @abstractmethod
21
+ def show_results(self):
22
+ pass
src/simple_regression_colorization/model/dataloaders.py CHANGED
@@ -1,14 +1,14 @@
1
  import tensorflow as tf
2
  from src.utils.data_utils import scale_L,scale_AB,rescale_AB,rescale_L
3
- from src.utils.config_loader import config
4
  from pathlib import Path
5
  from glob import glob
6
  import sklearn.model_selection
7
  from skimage.color import rgb2lab, lab2rgb
8
 
9
  def get_datasets():
10
- trainval_dir = config.PROCESSED_DATASET_DIR / Path("trainval/")
11
- test_dir = config.PROCESSED_DATASET_DIR / Path("test/")
12
 
13
  trainval_paths = glob(str(trainval_dir/Path("*")))
14
  test_paths = glob(str(test_dir/Path("*")))
 
1
  import tensorflow as tf
2
  from src.utils.data_utils import scale_L,scale_AB,rescale_AB,rescale_L
3
+ from src.utils.config_loader import config,constants
4
  from pathlib import Path
5
  from glob import glob
6
  import sklearn.model_selection
7
  from skimage.color import rgb2lab, lab2rgb
8
 
9
  def get_datasets():
10
+ trainval_dir = constants.PROCESSED_DATASET_DIR / Path("trainval/")
11
+ test_dir = constants.PROCESSED_DATASET_DIR / Path("test/")
12
 
13
  trainval_paths = glob(str(trainval_dir/Path("*")))
14
  test_paths = glob(str(test_dir/Path("*")))
src/simple_regression_colorization/model/models/model_v1.py CHANGED
@@ -1,30 +1,119 @@
1
  from src.simple_regression_colorization.model.base_model_interface import BaseModel
2
  from src.simple_regression_colorization.model.dataloaders import get_datasets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class Model(BaseModel):
5
 
6
- def __init__(self):
7
  # make model architecture
8
  # load weights (optional)
9
  # create dataset loaders
10
  # train
11
  # predict
12
  self.init_model()
13
- self.load_weights()
14
- self.prepare_data()
15
 
16
 
17
  def init_model(self):
18
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def load_weights(self,path=None):
21
- pass
 
22
 
23
  def prepare_data(self):
24
  self.train_ds,self.val_ds,self.test_ds = get_datasets()
25
 
26
  def train(self):
27
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- def predict(self,inputs):
30
- pass
 
1
  from src.simple_regression_colorization.model.base_model_interface import BaseModel
2
  from src.simple_regression_colorization.model.dataloaders import get_datasets
3
+ from src.utils.config_loader import config
4
+ from src.utils.data_utils import scale_L,scale_AB,rescale_AB,rescale_L,see_batch
5
+ from skimage.color import lab2rgb
6
+ import tensorflow as tf
7
+ from tensorflow.keras import layers,Model as keras_Model,Sequential
8
+ import numpy as np
9
+
10
+ def down(filters,kernel_size,apply_batch_normalization=True):
11
+ down = Sequential()
12
+ down.add(layers.Conv2D(filters,kernel_size,padding="same",strides=2))
13
+ if apply_batch_normalization:
14
+ down.add(layers.BatchNormalization())
15
+ down.add(layers.LeakyReLU())
16
+ return down
17
+
18
+ def up(filters,kernel_size,dropout=False):
19
+ upsample = Sequential()
20
+ upsample.add(layers.Conv2DTranspose(filters,kernel_size,padding="same",strides=2))
21
+ if dropout:
22
+ upsample.add(layers.Dropout(dropout))
23
+ upsample.add(layers.LeakyReLU())
24
+ return upsample
25
 
26
  class Model(BaseModel):
27
 
28
+ def __init__(self,path=None):
29
  # make model architecture
30
  # load weights (optional)
31
  # create dataset loaders
32
  # train
33
  # predict
34
  self.init_model()
35
+ self.load_weights(path)
 
36
 
37
 
38
  def init_model(self):
39
+ x = layers.Input([config.image_size,config.image_size,1])
40
+ d1 = down(128,(3,3),False)(x)
41
+ d2 = down(128,(3,3),False)(d1)
42
+ d3 = down(256,(3,3),True)(d2)
43
+ d4 = down(512,(3,3),True)(d3)
44
+ d5 = down(512,(3,3),True)(d4)
45
+
46
+ u1 = up(512,(3,3))(d5)
47
+ u1 = layers.concatenate([u1,d4])
48
+ u2 = up(256,(3,3))(u1)
49
+ u2 = layers.concatenate([u2,d3])
50
+ u3 = up(128,(3,3))(u2)
51
+ u3 = layers.concatenate([u3,d2])
52
+ u4 = up(128,(3,3))(u3)
53
+ u4 = layers.concatenate([u4,d1])
54
+ u5 = up(64,(3,3))(u4)
55
+ u5 = layers.concatenate([u5,x])
56
+
57
+ y = layers.Conv2D(2,(2,2),strides = 1, padding = 'same',activation="tanh")(u5)
58
+
59
+ self.model = keras_Model(x,y,name="UNet")
60
+
61
 
62
  def load_weights(self,path=None):
63
+ if path:
64
+ self.model.load_weights(path)
65
 
66
  def prepare_data(self):
67
  self.train_ds,self.val_ds,self.test_ds = get_datasets()
68
 
69
  def train(self):
70
+
71
+ self.prepare_data()
72
+ self.model.compile(optimizer="adam",loss="mse",metrics=["mae","acc"])
73
+ self.history = self.model.fit(self.train_ds,
74
+ validation_data=self.val_ds,
75
+ epochs=config.epochs)
76
+
77
+ def save(self,model_path):
78
+ self.model.save_weights(model_path)
79
+
80
+ def predict(self,L_batch):
81
+ L_batch = scale_L(L_batch)
82
+ AB_batch = self.model.predict(L_batch,verbose=0)
83
+ return rescale_AB(AB_batch)
84
+
85
+ def evaluate(self):
86
+ train_metrics = self.model.evaluate(self.train_ds)
87
+ val_metrics = self.model.evaluate(self.val_ds)
88
+ test_metrics = self.model.evaluate(self.test_ds)
89
+
90
+ return {
91
+ "train": train_metrics,
92
+ "val": val_metrics,
93
+ "test": test_metrics,
94
+ }
95
+
96
+ def predict_colors(self,L_batch):
97
+ AB_batch = self.predict(L_batch)
98
+ colored_batch = np.concatenate([L_batch,rescale_AB(AB_batch)],axis=-1)
99
+ colored_batch = lab2rgb(colored_batch) * 255
100
+ return colored_batch
101
+
102
+ def show_results(self):
103
+ self.prepare_data()
104
+
105
+ L_batch,AB_batch = next(iter(self.train_ds))
106
+ L_batch = L_batch.numpy()
107
+ AB_pred = self.model.predict(L_batch,verbose=0)
108
+ see_batch(L_batch,AB_pred,title="Train dataset Results")
109
+
110
+ L_batch,AB_batch = next(iter(self.val_ds))
111
+ L_batch = L_batch.numpy()
112
+ AB_pred = self.model.predict(L_batch,verbose=0)
113
+ see_batch(L_batch,AB_pred,title="Val dataset Results")
114
+
115
+ L_batch,AB_batch = next(iter(self.test_ds))
116
+ L_batch = L_batch.numpy()
117
+ AB_pred = self.model.predict(L_batch,verbose=0)
118
+ see_batch(L_batch,AB_pred,title="Test dataset Results")
119
 
 
 
src/utils/data_utils.py CHANGED
@@ -75,3 +75,26 @@ def show_images_from_paths(image_paths:list[str],image_size=64,cols=4,row_size=5
75
  img = np.concatenate([BW,img],axis=1)
76
  plt.imshow(img.astype("uint8"))
77
  plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  img = np.concatenate([BW,img],axis=1)
76
  plt.imshow(img.astype("uint8"))
77
  plt.show()
78
+
79
+
80
+ def see_batch(L_batch,AB_batch,show_L=False,cols=4,row_size=5,col_size=5,title=None):
81
+ n = L_batch.shape[0]
82
+ rows = math.ceil(n/cols)
83
+ fig = plt.figure(figsize=(col_size*cols,row_size*rows))
84
+ if title:
85
+ plt.title(title)
86
+ plt.axis("off")
87
+
88
+ for i in range(n):
89
+ fig.add_subplot(rows,cols,i+1)
90
+ L,AB = L_batch[i],AB_batch[i]
91
+ L,AB = rescale_L(L), rescale_AB(AB)
92
+ # print(L.shape,AB.shape)
93
+ img = np.concatenate([L,AB],axis=-1)
94
+ img = cv2.cvtColor(img,cv2.COLOR_LAB2RGB)*255
95
+ # print(img.min(),img.max())
96
+ if show_L:
97
+ L = np.tile(L,(1,1,3))/100*255
98
+ img = np.concatenate([L,img],axis=1)
99
+ plt.imshow(img.astype("uint8"))
100
+ plt.show()