Commit cf8627d (parent 821ffc1)
HAMIM-ML committed

model building updated
Files changed:
- main.py (+13, -0)
- research/model_trainer.ipynb (added)
- src/imagecolorization/config/configuration.py (+27, -1)
- src/imagecolorization/conponents/model_building.py (+4, -8)
- src/imagecolorization/conponents/model_trainer.py (+199, -0)
- src/imagecolorization/entity/config_entity.py (+16, -1)
- src/imagecolorization/pipeline/stage_04_model_trainer.py (+17, -0)
main.py CHANGED
@@ -1,6 +1,7 @@
 from src.imagecolorization.pipeline.stage01_data_ingestion import DataIngestionPipeline
 from src.imagecolorization.pipeline.stage02_data_transformation import DataTransformationPipeline
 from src.imagecolorization.pipeline.stage_03_model_building import ModelBuildingPipeline
+from src.imagecolorization.pipeline.stage_04_model_trainer import ModelTrainerPipeline
 from src.imagecolorization.logging import logger

 STAGE_NAME = 'Data Ingestion Config'
@@ -33,6 +34,18 @@ try:
     model_building = ModelBuildingPipeline()
     model_building.main()
     logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
+except Exception as e:
+    logger.exception(e)
+    raise e
+
+
+STAGE_NAME = 'Model Training Config'
+
+try:
+    logger.info(f">>>>>> stage {STAGE_NAME} started <<<<<<")
+    model_trainer = ModelTrainerPipeline()
+    model_trainer.main()
+    logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
 except Exception as e:
     logger.exception(e)
     raise e
research/model_trainer.ipynb ADDED
The diff for this file is too large to render. See the raw diff.
src/imagecolorization/config/configuration.py CHANGED
@@ -2,7 +2,8 @@ from src.imagecolorization.constants import *
 from src.imagecolorization.utils.common import read_yaml, create_directories
 from src.imagecolorization.entity.config_entity import (DataIngestionConfig,
                                                         DataTransformationConfig,
-                                                        ModelBuildingConfig)
+                                                        ModelBuildingConfig,
+                                                        ModelTrainerConfig)
 class ConfigurationManager:
     def __init__(
         self,
@@ -65,6 +66,31 @@ class ConfigurationManager:
             IN_CHANNELS=params.IN_CHANNELS
         )
         return model_building_config
+
+
+    def get_model_trainer_config(self) -> ModelTrainerConfig:
+        config = self.config.model_trainer
+        params = self.params
+
+        create_directories([config.root_dir])
+
+        # Convert LEARNING_RATE to float explicitly
+        learning_rate = float(params.LEARNING_RATE)
+
+        model_trainer_config = ModelTrainerConfig(
+            root_dir=config.root_dir,
+            test_data_path=config.test_data_path,
+            train_data_path=config.train_data_path,
+            LEARNING_RATE=learning_rate,  # use the converted float value
+            LAMBDA_RECON=params.LAMBDA_RECON,
+            DISPLAY_STEP=params.DISPLAY_STEP,
+            IMAGE_SIZE=params.IMAGE_SIZE,
+            INPUT_CHANNELS=params.INPUT_CHANNELS,
+            OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,
+            EPOCH=params.EPOCH,
+            BATCH_SIZE=params.BATCH_SIZE
+        )
+        return model_trainer_config
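The new accessor pulls its values from two files that are not part of this commit. A minimal usage sketch, assuming config.yaml has gained a model_trainer block (root_dir, train_data_path, test_data_path) and that params.yaml already defines the uppercase keys read above:

# Sketch: resolving the new trainer config (assumes the YAML keys above exist).
from src.imagecolorization.config.configuration import ConfigurationManager

config_manager = ConfigurationManager()
trainer_config = config_manager.get_model_trainer_config()
print(trainer_config.LEARNING_RATE)  # already coerced to float by the accessor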
src/imagecolorization/conponents/model_building.py CHANGED
@@ -2,9 +2,7 @@ import torch
 import torch.nn as nn
 from pathlib import Path
 from torchsummary import summary
-import
-from src.imagecolorization.config.configuration import ModelBuildingConfig
-
+from src.imagecolorization.entity.config_entity import ModelBuildingConfig
 class ResBlock(nn.Module):
     def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):
         super().__init__()
@@ -112,7 +110,9 @@ class Critic(nn.Module):
         return output


-
+from torchsummary import summary
+import torch
+import os

 class ModelBuilding:
     def __init__(self, config: ModelBuildingConfig):
@@ -162,7 +162,3 @@ class ModelBuilding:
         self.save_model(generator, "generator.pth")
         self.save_model(critic, "critic.pth")
         return generator, critic
-
-
-
-
src/imagecolorization/conponents/model_trainer.py ADDED
@@ -0,0 +1,199 @@
+import torch
+import numpy as np
+from skimage.color import rgb2lab, lab2rgb
+import torch
+from torch import nn, optim
+from torchvision import transforms
+from torch.utils.data import Dataset, DataLoader
+from torch.autograd import Variable
+from torchvision import models
+from torch.nn import functional as F
+import torch.utils.data
+from torchvision.models.inception import inception_v3
+from scipy.stats import entropy
+import pytorch_lightning as pl
+from torchsummary import summary
+from src.imagecolorization.conponents.model_building import Generator, Critic
+from src.imagecolorization.conponents.data_tranformation import ImageColorizationDataset
+from src.imagecolorization.logging import logger
+import gc
+from src.imagecolorization.entity.config_entity import ModelTrainerConfig
+
+def lab_to_rgb(L, ab):
+    L = L * 100
+    ab = (ab - 0.5) * 128 * 2
+    Lab = torch.cat([L, ab], dim=2).numpy()
+    rgb_img = []
+    for img in Lab:
+        img_rgb = lab2rgb(img)
+        rgb_img.append(img_rgb)
+
+    return np.stack(rgb_img, axis=0)
+
+
+
+import matplotlib.pyplot as plt
+def display_progress(cond, real, fake, current_epoch=0, figsize=(20, 15)):
+    """
+    Display cond, real (original) and generated (fake)
+    images in one panel
+    """
+    cond = cond.detach().cpu().permute(1, 2, 0)
+    real = real.detach().cpu().permute(1, 2, 0)
+    fake = fake.detach().cpu().permute(1, 2, 0)
+
+    images = [cond, real, fake]
+    titles = ['input', 'real', 'generated']
+    print(f'Epoch: {current_epoch}')
+    fig, ax = plt.subplots(1, 3, figsize=figsize)
+    for idx, img in enumerate(images):
+        if idx == 0:
+            ab = torch.zeros((224, 224, 2))
+            img = torch.cat([images[0] * 100, ab], dim=2).numpy()
+            imgan = lab2rgb(img)
+        else:
+            imgan = lab_to_rgb(images[0], img)
+        ax[idx].imshow(imgan)
+        ax[idx].axis("off")
+    for idx, title in enumerate(titles):
+        ax[idx].set_title('{}'.format(title))
+    plt.show()
+
+
+class CWGAN(pl.LightningModule):
+    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=100, display_step=10, lambda_gp=10, lambda_r1=10):
+        super().__init__()
+        self.save_hyperparameters()
+        self.display_step = display_step
+        self.generator = Generator(in_channels, out_channels)
+        self.critic = Critic(in_channels + out_channels)
+        self.lambda_recon = lambda_recon
+        self.lambda_gp = lambda_gp
+        self.lambda_r1 = lambda_r1
+        self.recon_criterion = nn.L1Loss()
+        self.generator_losses, self.critic_losses = [], []
+        self.automatic_optimization = False  # disable automatic optimization
+
+    def configure_optimizers(self):
+        optimizer_G = optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
+        optimizer_C = optim.Adam(self.critic.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
+        return [optimizer_C, optimizer_G]
+
+    def generator_step(self, real_images, conditioned_images, optimizer_G):
+        optimizer_G.zero_grad()
+        fake_images = self.generator(conditioned_images)
+        recon_loss = self.recon_criterion(fake_images, real_images)
+        recon_loss.backward()
+        optimizer_G.step()
+        self.generator_losses.append(recon_loss.item())
+
+    def critic_step(self, real_images, conditioned_images, optimizer_C):
+        optimizer_C.zero_grad()
+        fake_images = self.generator(conditioned_images)
+
+        # Separate L and ab channels
+        l_real = conditioned_images
+        ab_real = real_images
+        ab_fake = fake_images
+
+        # Compute logits
+        fake_logits = self.critic(ab_fake, l_real)  # pass two arguments
+        real_logits = self.critic(ab_real, l_real)  # pass two arguments
+
+        # Wasserstein critic loss: drive real logits up and fake logits down
+        loss_C = fake_logits.mean() - real_logits.mean()
+
+        # Compute the gradient penalty
+        alpha = torch.rand(real_images.size(0), 1, 1, 1, requires_grad=True).to(real_images.device)
+        interpolated = (alpha * ab_real + (1 - alpha) * ab_fake.detach()).requires_grad_(True)
+        interpolated_logits = self.critic(interpolated, l_real)
+
+        gradients = torch.autograd.grad(outputs=interpolated_logits, inputs=interpolated,
+                                        grad_outputs=torch.ones_like(interpolated_logits), create_graph=True, retain_graph=True)[0]
+        gradients = gradients.view(len(gradients), -1)
+        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+        loss_C += self.lambda_gp * gradients_penalty
+
+        # Compute the R1 regularization loss
+        r1_reg = gradients.pow(2).sum(1).mean()
+        loss_C += self.lambda_r1 * r1_reg
+
+        # Backpropagation
+        loss_C.backward()
+        optimizer_C.step()
+        self.critic_losses.append(loss_C.item())
+
+    def training_step(self, batch, batch_idx):
+        optimizer_C, optimizer_G = self.optimizers()
+        real_images, conditioned_images = batch
+
+        self.critic_step(real_images, conditioned_images, optimizer_C)
+        self.generator_step(real_images, conditioned_images, optimizer_G)
+
+        if self.current_epoch % self.display_step == 0 and batch_idx == 0:
+            with torch.no_grad():
+                fake_images = self.generator(conditioned_images)
+                display_progress(conditioned_images[0], real_images[0], fake_images[0], self.current_epoch)
+
+    def on_epoch_end(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def forward(self, x):
+        return self.generator(x)
+
+
+import os
+class ModelTrainer:
+    def __init__(self, config: ModelTrainerConfig):
+        self.config = config
+
+    def load_datasets(self):
+        self.train_dataset = torch.load(self.config.train_data_path)
+        self.test_dataset = torch.load(self.config.test_data_path)
+
+    def create_dataloaders(self):
+        self.train_dataloader = DataLoader(
+            self.train_dataset, batch_size=self.config.BATCH_SIZE, shuffle=True
+        )
+        self.test_dataloader = DataLoader(
+            self.test_dataset, batch_size=self.config.BATCH_SIZE, shuffle=False
+        )
+
+    def initialize_model(self):
+        self.model = CWGAN(
+            in_channels=self.config.INPUT_CHANNELS,
+            out_channels=self.config.OUTPUT_CHANNELS,
+            learning_rate=self.config.LEARNING_RATE,
+            lambda_recon=self.config.LAMBDA_RECON,
+            display_step=self.config.DISPLAY_STEP,
+            lambda_gp=10,  # default value; could be made configurable
+            lambda_r1=10   # default value; could be made configurable
+        )
+
+    def train_model(self):
+        checkpoint_callback = pl.callbacks.ModelCheckpoint(
+            dirpath=self.config.root_dir,
+            filename='cwgan-{epoch:02d}-{generator_loss:.2f}',
+            save_top_k=-1,  # save all checkpoints
+            verbose=True
+        )
+
+        trainer = pl.Trainer(
+            max_epochs=self.config.EPOCH,
+            callbacks=[checkpoint_callback],
+        )
+
+        trainer.fit(self.model, self.train_dataloader)
+
+    def save_model(self):
+        trained_model_dir = self.config.root_dir
+        os.makedirs(trained_model_dir, exist_ok=True)
+
+        generator_path = os.path.join(trained_model_dir, "cwgan_generator_final.pt")
+        critic_path = os.path.join(trained_model_dir, "cwgan_critic_final.pt")
+
+        torch.save(self.model.generator.state_dict(), generator_path)
+        torch.save(self.model.critic.state_dict(), critic_path)
+        logger.info(f"Final models saved at {trained_model_dir}")
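lab_to_rgb and display_progress assume channels-last, single-image tensors after the permute (L of shape (H, W, 1), ab of shape (H, W, 2)). A small shape check, with random values standing in for real model output:

# Sketch: expected tensor shapes for lab_to_rgb (random data, not model output).
import torch
from src.imagecolorization.conponents.model_trainer import lab_to_rgb

L = torch.rand(224, 224, 1)   # normalized L channel in [0, 1], channels-last
ab = torch.rand(224, 224, 2)  # normalized ab channels in [0, 1], channels-last
rgb = lab_to_rgb(L, ab)       # concatenates to Lab, converts row by row with lab2rgb
print(rgb.shape)              # (224, 224, 3), float RGB in [0, 1]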
src/imagecolorization/entity/config_entity.py CHANGED
@@ -32,4 +32,19 @@ class ModelBuildingConfig:
     KERNEL_SIZE_GENERATOR: int
     INPUT_CHANNELS: int
     OUTPUT_CHANNELS: int
-    IN_CHANNELS: int
+    IN_CHANNELS: int
+
+
+@dataclass(frozen=True)
+class ModelTrainerConfig:
+    root_dir: Path
+    test_data_path: Path
+    train_data_path: Path
+    LEARNING_RATE: float
+    LAMBDA_RECON: int
+    DISPLAY_STEP: int
+    IMAGE_SIZE: list
+    INPUT_CHANNELS: int
+    OUTPUT_CHANNELS: int
+    EPOCH: int
+    BATCH_SIZE: int
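For reference, the frozen dataclass can also be filled in by hand; every path and number below is an illustrative placeholder, not a value taken from this repo:

# Sketch: constructing ModelTrainerConfig directly (all values hypothetical).
from pathlib import Path
from src.imagecolorization.entity.config_entity import ModelTrainerConfig

trainer_config = ModelTrainerConfig(
    root_dir=Path("artifacts/model_trainer"),            # placeholder path
    test_data_path=Path("artifacts/test_dataset.pt"),    # placeholder path
    train_data_path=Path("artifacts/train_dataset.pt"),  # placeholder path
    LEARNING_RATE=2e-4,
    LAMBDA_RECON=100,
    DISPLAY_STEP=10,
    IMAGE_SIZE=[224, 224],
    INPUT_CHANNELS=1,   # grayscale L channel
    OUTPUT_CHANNELS=2,  # predicted ab channels
    EPOCH=10,
    BATCH_SIZE=16,
)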
src/imagecolorization/pipeline/stage_04_model_trainer.py ADDED
@@ -0,0 +1,17 @@
+from src.imagecolorization.conponents.model_trainer import ModelTrainer
+from src.imagecolorization.config.configuration import ConfigurationManager
+
+
+class ModelTrainerPipeline:
+    def __init__(self):
+        pass
+
+    def main(self):
+        config_manager = ConfigurationManager()
+        model_trainer_config = config_manager.get_model_trainer_config()
+        model_trainer = ModelTrainer(config=model_trainer_config)
+        model_trainer.load_datasets()
+        model_trainer.create_dataloaders()
+        model_trainer.initialize_model()
+        model_trainer.train_model()
+        model_trainer.save_model()
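The new stage follows the same pattern as the earlier pipelines, so it can be exercised on its own; a minimal sketch, assuming the data-ingestion and data-transformation artifacts already exist on disk:

# Sketch: running only the training stage (earlier stages must have run first).
from src.imagecolorization.pipeline.stage_04_model_trainer import ModelTrainerPipeline

if __name__ == "__main__":
    ModelTrainerPipeline().main()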