HAMIM-ML committed on
Commit
cf8627d
·
1 Parent(s): 821ffc1

model building updated

Browse files
main.py CHANGED
@@ -1,6 +1,7 @@
1
  from src.imagecolorization.pipeline.stage01_data_ingestion import DataIngestionPipeline
2
  from src.imagecolorization.pipeline.stage02_data_transformation import DataTransformationPipeline
3
  from src.imagecolorization.pipeline.stage_03_model_building import ModelBuildingPipeline
 
4
  from src.imagecolorization.logging import logger
5
 
6
  STAGE_NAME = 'Data Ingestion Config'
@@ -33,6 +34,18 @@ try:
33
  model_building = ModelBuildingPipeline()
34
  model_building.main()
35
  logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
 
 
 
 
 
 
 
 
 
 
 
 
36
  except Exception as e:
37
  logger.exception(e)
38
  raise e
 
1
  from src.imagecolorization.pipeline.stage01_data_ingestion import DataIngestionPipeline
2
  from src.imagecolorization.pipeline.stage02_data_transformation import DataTransformationPipeline
3
  from src.imagecolorization.pipeline.stage_03_model_building import ModelBuildingPipeline
4
+ from src.imagecolorization.pipeline.stage_04_model_trainer import ModelTrainerPipeline
5
  from src.imagecolorization.logging import logger
6
 
7
  STAGE_NAME = 'Data Ingestion Config'
 
34
  model_building = ModelBuildingPipeline()
35
  model_building.main()
36
  logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
37
+ except Exception as e:
38
+ logger.exception(e)
39
+ raise e
40
+
41
+
42
# ---------------------------------------------------------------------------
# Stage: model training
# Runs the training pipeline, logging a banner before and after the stage.
# ---------------------------------------------------------------------------
STAGE_NAME = 'Model Training Config'

try:
    logger.info(f">>>>>> stage {STAGE_NAME} started <<<<<<")
    model_trainer = ModelTrainerPipeline()
    model_trainer.main()
    logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
except Exception as e:
    # Record the full traceback, then re-raise the same exception so the
    # script still fails. A bare `raise` keeps the original traceback intact.
    logger.exception(e)
    raise
research/model_trainer.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/imagecolorization/config/configuration.py CHANGED
@@ -2,7 +2,8 @@ from src.imagecolorization.constants import *
2
  from src.imagecolorization.utils.common import read_yaml, create_directories
3
  from src.imagecolorization.entity.config_entity import (DataIngestionConfig,
4
  DataTransformationConfig,
5
- ModelBuildingConfig)
 
6
  class ConfigurationManager:
7
  def __init__(
8
  self,
@@ -65,6 +66,31 @@ class ConfigurationManager:
65
  IN_CHANNELS=params.IN_CHANNELS
66
  )
67
  return model_building_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
 
 
2
  from src.imagecolorization.utils.common import read_yaml, create_directories
3
  from src.imagecolorization.entity.config_entity import (DataIngestionConfig,
4
  DataTransformationConfig,
5
+ ModelBuildingConfig,
6
+ ModelTrainerConfig)
7
  class ConfigurationManager:
8
  def __init__(
9
  self,
 
66
  IN_CHANNELS=params.IN_CHANNELS
67
  )
68
  return model_building_config
69
+
70
+
71
def get_model_trainer_config(self) -> ModelTrainerConfig:
    """Assemble the ``ModelTrainerConfig`` for the training stage.

    Reads paths from the ``model_trainer`` section of ``self.config`` and
    hyper-parameters from ``self.params``, and makes sure the stage's root
    directory exists before returning.

    Returns:
        ModelTrainerConfig: populated configuration entity.
    """
    trainer_section = self.config.model_trainer
    hyperparams = self.params

    create_directories([trainer_section.root_dir])

    # LEARNING_RATE is coerced to float explicitly.
    # NOTE(review): presumably because the params file may yield it as a
    # string (e.g. scientific notation) -- confirm against params.yaml.
    learning_rate = float(hyperparams.LEARNING_RATE)

    # Hyper-parameters copied through unchanged, in constructor order.
    passthrough = (
        "LAMBDA_RECON",
        "DISPLAY_STEP",
        "IMAGE_SIZE",
        "INPUT_CHANNELS",
        "OUTPUT_CHANNELS",
        "EPOCH",
        "BATCH_SIZE",
    )

    return ModelTrainerConfig(
        root_dir=trainer_section.root_dir,
        test_data_path=trainer_section.test_data_path,
        train_data_path=trainer_section.train_data_path,
        LEARNING_RATE=learning_rate,
        **{name: getattr(hyperparams, name) for name in passthrough},
    )
94
 
95
 
96
 
src/imagecolorization/conponents/model_building.py CHANGED
@@ -2,9 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from pathlib import Path
4
  from torchsummary import summary
5
- import os
6
- from src.imagecolorization.config.configuration import ModelBuildingConfig
7
-
8
  class ResBlock(nn.Module):
9
  def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):
10
  super().__init__()
@@ -112,7 +110,9 @@ class Critic(nn.Module):
112
  return output
113
 
114
 
115
-
 
 
116
 
117
  class ModelBuilding:
118
  def __init__(self, config: ModelBuildingConfig):
@@ -162,7 +162,3 @@ class ModelBuilding:
162
  self.save_model(generator, "generator.pth")
163
  self.save_model(critic, "critic.pth")
164
  return generator, critic
165
-
166
-
167
-
168
-
 
2
  import torch.nn as nn
3
  from pathlib import Path
4
  from torchsummary import summary
5
+ from src.imagecolorization.entity.config_entity import ModelBuildingConfig
 
 
6
  class ResBlock(nn.Module):
7
  def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):
8
  super().__init__()
 
110
  return output
111
 
112
 
113
+ from torchsummary import summary
114
+ import torch
115
+ import os
116
 
117
  class ModelBuilding:
118
  def __init__(self, config: ModelBuildingConfig):
 
162
  self.save_model(generator, "generator.pth")
163
  self.save_model(critic, "critic.pth")
164
  return generator, critic
 
 
 
 
src/imagecolorization/conponents/model_trainer.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from skimage.color import rgb2lab, lab2rgb
4
+ import torch
5
+ from torch import nn, optim
6
+ from torchvision import transforms
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torch.autograd import Variable
9
+ from torchvision import models
10
+ from torch.nn import functional as F
11
+ import torch.utils.data
12
+ from torchvision.models.inception import inception_v3
13
+ from scipy.stats import entropy
14
+ import pytorch_lightning as pl
15
+ from torchsummary import summary
16
+ from src.imagecolorization.conponents.model_building import Generator, Critic
17
+ from src.imagecolorization.conponents.data_tranformation import ImageColorizationDataset
18
+ from src.imagecolorization.logging import logger
19
+ import gc
20
+ from src.imagecolorization.entity.config_entity import ModelTrainerConfig
21
+
22
def lab_to_rgb(L, ab):
    """Convert normalised L and ab tensors to an RGB NumPy array.

    The inputs are rescaled (``L * 100`` and ``(ab - 0.5) * 256``),
    concatenated along the last axis and passed to ``skimage``'s ``lab2rgb``.

    Args:
        L: CPU tensor whose last axis holds the single lightness channel,
            e.g. ``(H, W, 1)`` or ``(B, H, W, 1)``.
            NOTE(review): assumes values are normalised to [0, 1] -- confirm
            against the dataset transform.
        ab: CPU tensor with the same leading dimensions as ``L`` and two
            channels on the last axis.

    Returns:
        numpy.ndarray: RGB image(s) with the same leading dimensions as the
        inputs and three channels on the last axis.
    """
    L = L * 100
    ab = (ab - 0.5) * 128 * 2
    # Concatenate on the channel (last) axis so both single images and
    # channel-last batches are supported.
    lab = torch.cat([L, ab], dim=-1).numpy()
    # lab2rgb works on any (..., 3) array, so one call converts everything
    # instead of looping over slices of the first axis.
    return lab2rgb(lab)
32
+
33
+
34
+
35
+ import matplotlib.pyplot as plt
36
def display_progress(cond, real, fake, current_epoch = 0, figsize=(20,15)):
    """Show the conditioning, real and generated images side by side.

    The three tensors are detached, moved to the CPU and permuted with
    ``permute(1, 2, 0)`` before plotting. The conditioning image is rendered
    as a colourless Lab image (zero ab channels); the other two are rendered
    by combining the conditioning L channel with their ab channels.

    Args:
        cond: conditioning (L channel) tensor for one sample.
            NOTE(review): assumed channel-first ``(1, H, W)`` -- confirm.
        real: ground-truth ab tensor for the same sample.
        fake: generated ab tensor for the same sample.
        current_epoch: epoch number printed above the figure.
        figsize: matplotlib figure size.
    """
    cond = cond.detach().cpu().permute(1, 2, 0)
    real = real.detach().cpu().permute(1, 2, 0)
    fake = fake.detach().cpu().permute(1, 2, 0)

    print(f'Epoch: {current_epoch}')
    fig, ax = plt.subplots(1, 3, figsize=figsize)

    # Size the zero ab plane from the input rather than hard-coding 224x224,
    # so images of any resolution can be displayed.
    height, width = cond.shape[:2]
    zero_ab = torch.zeros((height, width, 2))
    grayscale = lab2rgb(torch.cat([cond * 100, zero_ab], dim=2).numpy())

    panels = [
        ('input', grayscale),
        ('real', lab_to_rgb(cond, real)),
        ('generated', lab_to_rgb(cond, fake)),
    ]
    for axis, (title, rgb) in zip(ax, panels):
        axis.imshow(rgb)
        axis.axis("off")
        axis.set_title(title)
    plt.show()
61
+
62
+
63
class CWGAN(pl.LightningModule):
    """Conditional WGAN for colorization, trained with manual optimization.

    Holds a ``Generator`` (conditioning image -> predicted channels) and a
    ``Critic`` that scores a (channels, conditioning image) pair. Both
    networks get their own Adam optimizer and are stepped explicitly inside
    ``training_step``.

    Args:
        in_channels: channels of the conditioning input given to the generator.
        out_channels: channels produced by the generator.
        learning_rate: Adam learning rate shared by both optimizers.
        lambda_recon: reconstruction weight. It is stored, but the current
            ``generator_step`` does not apply it.
        display_step: a preview is shown every ``display_step`` epochs.
        lambda_gp: weight of the gradient-penalty term in the critic loss.
        lambda_r1: weight of the squared-gradient (R1-style) term in the
            critic loss.
    """

    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=100, display_step=10, lambda_gp=10, lambda_r1=10):
        super().__init__()
        self.save_hyperparameters()
        self.display_step = display_step
        self.generator = Generator(in_channels, out_channels)
        self.critic = Critic(in_channels + out_channels)
        self.lambda_recon = lambda_recon
        self.lambda_gp = lambda_gp
        self.lambda_r1 = lambda_r1
        self.recon_criterion = nn.L1Loss()
        # Per-step loss history (plain Python floats).
        self.generator_losses, self.critic_losses = [], []
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

    def configure_optimizers(self):
        """Return ``[critic_optimizer, generator_optimizer]`` (Adam, betas 0.5/0.9)."""
        optimizer_G = optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
        optimizer_C = optim.Adam(self.critic.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
        # Order matters: training_step unpacks the critic optimizer first.
        return [optimizer_C, optimizer_G]

    def generator_step(self, real_images, conditioned_images, optimizer_G):
        """Run one generator update using the L1 reconstruction loss only.

        No adversarial term is included and ``lambda_recon`` is not applied;
        the raw L1 loss is back-propagated and appended to
        ``self.generator_losses``.
        """
        optimizer_G.zero_grad()
        fake_images = self.generator(conditioned_images)
        recon_loss = self.recon_criterion(fake_images, real_images)
        recon_loss.backward()
        optimizer_G.step()
        self.generator_losses.append(recon_loss.item())

    def critic_step(self, real_images, conditioned_images, optimizer_C):
        """Run one critic update.

        The loss is ``mean(real_logits) - mean(fake_logits)`` plus a gradient
        penalty and a squared-gradient term, both computed from the critic's
        gradient at random interpolations between real and generated samples.
        The scalar loss is appended to ``self.critic_losses``.
        """
        optimizer_C.zero_grad()
        fake_images = self.generator(conditioned_images)

        # Separate L and ab channels
        l_real = conditioned_images
        ab_real = real_images
        ab_fake = fake_images

        # Critic scores for generated and real pairs (critic takes two inputs).
        fake_logits = self.critic(ab_fake, l_real)
        real_logits = self.critic(ab_real, l_real)

        # Base critic loss
        loss_C = real_logits.mean() - fake_logits.mean()

        # Gradient penalty on samples interpolated between real and fake.
        # The generator output is detached so only the critic is penalised.
        alpha = torch.rand(real_images.size(0), 1, 1, 1, requires_grad=True).to(real_images.device)
        interpolated = (alpha * ab_real + (1 - alpha) * ab_fake.detach()).requires_grad_(True)
        interpolated_logits = self.critic(interpolated, l_real)

        gradients = torch.autograd.grad(outputs=interpolated_logits, inputs=interpolated,
                                        grad_outputs=torch.ones_like(interpolated_logits), create_graph=True, retain_graph=True)[0]
        # Flatten per sample so the norm is taken over all non-batch dims.
        gradients = gradients.view(len(gradients), -1)
        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        loss_C += self.lambda_gp * gradients_penalty

        # Squared-gradient regulariser, computed on the same interpolated
        # gradients as the penalty above.
        r1_reg = gradients.pow(2).sum(1).mean()
        loss_C += self.lambda_r1 * r1_reg

        # Backpropagation
        loss_C.backward()
        optimizer_C.step()
        self.critic_losses.append(loss_C.item())

    def training_step(self, batch, batch_idx):
        """Update the critic, then the generator, on one batch.

        ``batch`` is unpacked as ``(real_images, conditioned_images)``. Both
        losses are logged as epoch-level metrics, and on the first batch of
        every ``display_step``-th epoch a preview of the first sample is shown.
        """
        optimizer_C, optimizer_G = self.optimizers()
        real_images, conditioned_images = batch

        self.critic_step(real_images, conditioned_images, optimizer_C)
        self.generator_step(real_images, conditioned_images, optimizer_G)

        # Expose the losses to Lightning. The checkpoint filename template
        # used by the trainer references `generator_loss`, which is only
        # available to callbacks if it is logged.
        self.log("generator_loss", self.generator_losses[-1], on_step=False, on_epoch=True, prog_bar=True)
        self.log("critic_loss", self.critic_losses[-1], on_step=False, on_epoch=True, prog_bar=True)

        if self.current_epoch % self.display_step == 0 and batch_idx == 0:
            with torch.no_grad():
                fake_images = self.generator(conditioned_images)
            display_progress(conditioned_images[0], real_images[0], fake_images[0], self.current_epoch)

    def on_train_epoch_end(self):
        """Free Python garbage and cached CUDA memory after each training epoch."""
        gc.collect()
        torch.cuda.empty_cache()

    def on_epoch_end(self):
        """Legacy hook name kept for callers; delegates to ``on_train_epoch_end``.

        NOTE(review): newer Lightning releases no longer invoke this hook,
        which is why the cleanup lives in ``on_train_epoch_end`` -- confirm
        against the installed Lightning version.
        """
        self.on_train_epoch_end()

    def forward(self, x):
        """Return the generator's output for the conditioning input ``x``."""
        return self.generator(x)
144
+
145
+
146
+ import os
147
class ModelTrainer:
    """Drive CWGAN training: load data, build loaders and model, fit, save.

    The methods are meant to be called in order: ``load_datasets``,
    ``create_dataloaders``, ``initialize_model``, ``train_model``,
    ``save_model``. Each one stores its result on the instance for the next.
    """

    def __init__(self, config : ModelTrainerConfig):
        self.config = config

    def load_datasets(self):
        """Load the serialized train and test datasets from the configured paths."""
        # NOTE(security): torch.load unpickles its input; only point these
        # paths at files produced by this project's own pipeline.
        # NOTE(review): recent PyTorch releases may reject pickled Dataset
        # objects unless weights_only=False is passed -- confirm against the
        # installed version.
        cfg = self.config
        self.train_dataset = torch.load(cfg.train_data_path)
        self.test_dataset = torch.load(cfg.test_data_path)

    def create_dataloaders(self):
        """Wrap both datasets in DataLoaders; only the training one is shuffled."""
        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=True,
        )
        self.test_dataloader = DataLoader(
            self.test_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=False,
        )

    def initialize_model(self):
        """Instantiate the CWGAN from the configured hyper-parameters."""
        cfg = self.config
        model_kwargs = dict(
            in_channels=cfg.INPUT_CHANNELS,
            out_channels=cfg.OUTPUT_CHANNELS,
            learning_rate=cfg.LEARNING_RATE,
            lambda_recon=cfg.LAMBDA_RECON,
            display_step=cfg.DISPLAY_STEP,
            # Fixed defaults for now; could be moved into the config.
            lambda_gp=10,
            lambda_r1=10,
        )
        self.model = CWGAN(**model_kwargs)

    def train_model(self):
        """Fit the model on the training loader, checkpointing into ``root_dir``."""
        # save_top_k=-1 keeps every checkpoint instead of only the best ones.
        keep_all_checkpoints = pl.callbacks.ModelCheckpoint(
            dirpath=self.config.root_dir,
            filename='cwgan-{epoch:02d}-{generator_loss:.2f}',
            save_top_k=-1,
            verbose=True,
        )
        fitter = pl.Trainer(
            max_epochs=self.config.EPOCH,
            callbacks=[keep_all_checkpoints],
        )
        fitter.fit(self.model, self.train_dataloader)

    def save_model(self):
        """Write the final generator and critic state dicts into ``root_dir``."""
        trained_model_dir = self.config.root_dir
        os.makedirs(trained_model_dir, exist_ok=True)

        generator_path, critic_path = (
            os.path.join(trained_model_dir, file_name)
            for file_name in ("cwgan_generator_final.pt", "cwgan_critic_final.pt")
        )

        for network, destination in (
            (self.model.generator, generator_path),
            (self.model.critic, critic_path),
        ):
            torch.save(network.state_dict(), destination)
        logger.info(f"Final models saved at {trained_model_dir}")
src/imagecolorization/entity/config_entity.py CHANGED
@@ -32,4 +32,19 @@ class ModelBuildingConfig:
32
  KERNEL_SIZE_GENERATOR: int
33
  INPUT_CHANNELS: int
34
  OUTPUT_CHANNELS: int
35
- IN_CHANNELS: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  KERNEL_SIZE_GENERATOR: int
33
  INPUT_CHANNELS: int
34
  OUTPUT_CHANNELS: int
35
+ IN_CHANNELS: int
36
+
37
+
38
@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable settings for the model-training stage.

    Path fields come from the ``model_trainer`` section of the config file;
    upper-case fields are hyper-parameters read from the params file.
    """
    # Stage output directory; checkpoints and final weights are written here.
    root_dir: Path
    # Serialized test dataset (loaded with torch.load by the trainer).
    test_data_path: Path
    # Serialized training dataset (loaded with torch.load by the trainer).
    train_data_path: Path
    # Learning rate shared by the generator and critic optimizers.
    LEARNING_RATE: float
    # Reconstruction-loss weight handed to the CWGAN constructor.
    LAMBDA_RECON: int
    # Show a preview image every DISPLAY_STEP epochs.
    DISPLAY_STEP: int
    # NOTE(review): not read by the visible trainer code -- confirm usage.
    IMAGE_SIZE: list
    # Channels of the generator's conditioning input.
    INPUT_CHANNELS: int
    # Channels produced by the generator.
    OUTPUT_CHANNELS: int
    # Maximum number of training epochs.
    EPOCH: int
    # Batch size used for both the train and test DataLoaders.
    BATCH_SIZE : int
src/imagecolorization/pipeline/stage_04_model_trainer.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.imagecolorization.conponents.model_trainer import ModelTrainer
2
+ from src.imagecolorization.config.configuration import ConfigurationManager
3
+
4
+
5
class ModelTrainerPipeline:
    """Pipeline stage that trains the colorization model end to end."""

    def __init__(self):
        """The pipeline keeps no state of its own."""

    def main(self):
        """Build the trainer from configuration and run every training step.

        Steps run in order: load datasets, create dataloaders, initialise the
        model, train it, and save the final weights.
        """
        trainer_config = ConfigurationManager().get_model_trainer_config()
        trainer = ModelTrainer(config=trainer_config)

        trainer.load_datasets()
        trainer.create_dataloaders()
        trainer.initialize_model()
        trainer.train_model()
        trainer.save_model()