In [1]:
import os
os.chdir('../')

In [2]:
%pwd

'c:\\mlops project\\image-colorization-mlops'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataTransformationConfig:
    root_dir : Path
    data_path_black : Path
    data_path_grey : Path
    BATCH_SIZE : int
    IMAGE_SIZE : list
    DATA_RANGE: int


In [4]:
from src.imagecolorization.constants import *
from src.imagecolorization.utils.common import read_yaml, create_directories

class ConfigurationManager:
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])
        
        
    def get_data_transformation_config(self) -> DataTransformationConfig:
        config = self.config.data_transformation
        params = self.params
        
        data_transformation_cofig = DataTransformationConfig(
            root_dir= config.root_dir,
            data_path_black=config.data_path_black,
            data_path_grey=config.data_path_grey,
            BATCH_SIZE=params.BATCH_SIZE,
            IMAGE_SIZE=params.IMAGE_SIZE,
            DATA_RANGE=params.DATA_RANGE
        )
        
        return data_transformation_cofig

In [5]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

        
        
        
class ImageColorizationDataset:
    def __init__(self, dataset, image_size, transform = None):
        self.dataset = dataset
        self.transform = transform
        self.image_size = tuple(image_size)
        
    def __len__(self):
        return len(self.dataset[0])
    
    def __getitem__(self, idx):
        L  = np.array(self.dataset[0][idx]).reshape(self.image_size)
        L = transforms.ToTensor()(L)
        
        ab = np.array(self.dataset[1][idx])
        ab = transforms.ToTensor()(ab)
        
        return ab , L

In [12]:
from torch.utils.data import DataLoader
import gc
import os
import numpy as np
import torch
from src.imagecolorization.logging import logger

class DataTransformation:
    def __init__(self, config: DataTransformationConfig):
        self.config = config
        
    def load_data(self):
        ab_df = np.load(self.config.data_path_black)[:self.config.DATA_RANGE]
        l_df = np.load(self.config.data_path_grey)[:self.config.DATA_RANGE]
        dataset = (l_df, ab_df)
        gc.collect()
        return dataset
    
    def get_datasets(self, dataset):
        train_dataset = ImageColorizationDataset(
            dataset=dataset,
            image_size=self.config.IMAGE_SIZE
        )
        test_dataset = ImageColorizationDataset(
            dataset=dataset,
            image_size=self.config.IMAGE_SIZE
        )
    
        return train_dataset, test_dataset
        
        
    
    def save_datasets(self, train_dataset, test_dataset):
        # Ensure the directory exists
        os.makedirs(self.config.root_dir, exist_ok=True)

        train_dataset_path = os.path.join(self.config.root_dir, 'train_dataset.pt')
        test_dataset_path = os.path.join(self.config.root_dir, 'test_dataset.pt')

        try:
            # Save the datasets
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)

            logger.info(f"Train dataset saved at: {train_dataset_path}")
            logger.info(f"Test dataset saved at: {test_dataset_path}")
        except Exception as e:
            logger.error(f"Error saving datasets: {str(e)}")
            raise e

In [13]:
try:
    config = ConfigurationManager()
    data_transformation_config = config.get_data_transformation_config()
    data_transformation = DataTransformation(config=data_transformation_config)
    
    # Load the dataset
    dataset = data_transformation.load_data()
    
    # Get the datasets using the loaded dataset
    train_dataset, test_dataset = data_transformation.get_datasets(dataset)
    
    # Perform any further operations (e.g., saving the dataset)
    data_transformation.save_datasets(train_dataset, test_dataset)
    
    # Print information about the train dataset
    print(f"Train dataset type: {type(train_dataset)}")
    print(f"Train dataset length: {len(train_dataset)}")
    print(f"First item in train dataset: {train_dataset[0]}")
    
except Exception as e:
    raise e

[2024-08-24 23:35:18,235: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-08-24 23:35:18,243: INFO: common: yaml file: params.yaml loaded successfully]
[2024-08-24 23:35:18,245: INFO: common: created directory at: artifacts]
[2024-08-24 23:35:34,541: INFO: 2411080742: Train dataset saved at: artifacts/data_transformation\train_dataset.pt]
[2024-08-24 23:35:34,552: INFO: 2411080742: Test dataset saved at: artifacts/data_transformation\test_dataset.pt]
Train dataset type: <class '__main__.ImageColorizationDataset'>
Train dataset length: 5000
First item in train dataset: (tensor([[[0.5059, 0.4941, 0.4941,  ..., 0.4902, 0.4863, 0.4863],
         [0.4980, 0.4941, 0.4941,  ..., 0.4980, 0.4902, 0.4824],
         [0.4980, 0.4980, 0.4980,  ..., 0.4902, 0.4941, 0.5020],
         ...,
         [0.4941, 0.4941, 0.4980,  ..., 0.4941, 0.4941, 0.4941],
         [0.4941, 0.4980, 0.4941,  ..., 0.4941, 0.4941, 0.4941],
         [0.4980, 0.4980, 0.5020,  ..., 0.4941, 0.4863, 0.4941

In [6]:
from torch.utils.data import Dataset, DataLoader


test_dataset = torch.load('artifacts/data_transformation/test_dataset.pt')
train_dataset = torch.load('artifacts/data_transformation/train_dataset.pt')


train_loader1 = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=1
)
test_loader1 = DataLoader(
    test_dataset,
    shuffle=True,
    batch_size=1
)



  test_dataset = torch.load('artifacts/data_transformation/test_dataset.pt')
  train_dataset = torch.load('artifacts/data_transformation/train_dataset.pt')


In [7]:
# Print shapes for train dataloader
for batch in train_loader1:
    ab, L = batch
    print(f"Train loader batch - ab shape: {ab.shape}, L shape: {L.shape}")
    break  # We only need to check one batch

# Print shapes for test dataloader
for batch in test_loader1:
     ab, L = batch
     print(f"Test loader batch - ab shape: {ab.shape}, L shape: {L.shape}")
     break  # We only need to check one batch

Train loader batch - ab shape: torch.Size([1, 2, 224, 224]), L shape: torch.Size([1, 1, 224, 224])
Test loader batch - ab shape: torch.Size([1, 2, 224, 224]), L shape: torch.Size([1, 1, 224, 224])


In [8]:
from skimage.color import rgb2lab, lab2rgb
import gc
import matplotlib.pylab as plt

def lab_to_rgb(L, ab):
    L = L * 100
    ab = (ab - 0.5) * 128 * 2
    Lab = torch.cat([L, ab], dim = 2).numpy()
    rgb_img = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_img.append(img_rgb)
        
    return np.stack(rgb_img, axis = 0)




In [9]:
def display_progress(cond, real, fake, current_epoch = 0, figsize=(20,15)):
    """
    Save cond, real (original) and generated (fake)
    images in one panel 
    """
    cond = cond.detach().cpu().permute(1, 2, 0)   
    real = real.detach().cpu().permute(1, 2, 0)
    fake = fake.detach().cpu().permute(1, 2, 0)
    
    images = [cond, real, fake]
    titles = ['input','real','generated']
    print(f'Epoch: {current_epoch}')
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for idx,img in enumerate(images):
        if idx == 0:
            ab = torch.zeros((224,224,2))
            img = torch.cat([images[0]* 100, ab], dim=2).numpy()
            imgan = lab2rgb(img)
        else:
            imgan = lab_to_rgb(images[0],img)
        ax[idx].imshow(imgan)
        ax[idx].axis("off")
    for idx, title in enumerate(titles):    
        ax[idx].set_title('{}'.format(title))
    plt.show()


In [13]:
import pytorch_lightning as pl
from src.imagecolorization.conponents.model_trainer import Generator,Critic
from torch import nn, optim

class CWGAN(pl.LightningModule):
    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=100, display_step=10, lambda_gp=10, lambda_r1=10):
        super().__init__()
        self.save_hyperparameters()
        self.display_step = display_step
        self.generator = Generator(in_channels, out_channels)
        self.critic = Critic(in_channels + out_channels)  # Ensure Critic is initialized with the correct input channels
        self.lambda_recon = lambda_recon
        self.lambda_gp = lambda_gp
        self.lambda_r1 = lambda_r1
        self.recon_criterion = nn.L1Loss()
        self.generator_losses, self.critic_losses = [], []
        self.automatic_optimization = False
        
    def configure_optimizers(self):
        optimizer_G = optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
        optimizer_C = optim.Adam(self.critic.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.9))
        return [optimizer_C, optimizer_G]
    
    def generator_step(self, real_images, conditioned_images, optimizer_G):
        # WGAN has only a reconstruction loss
        optimizer_G.zero_grad()
        fake_images = self.generator(conditioned_images)
        recon_loss = self.recon_criterion(fake_images, real_images)
        recon_loss.backward()
        optimizer_G.step()
        self.generator_losses.append(recon_loss.item())
        
    def critic_step(self, real_images, conditioned_images, optimizer_C):
        optimizer_C.zero_grad()
        fake_images = self.generator(conditioned_images)
        fake_input = torch.cat((fake_images, conditioned_images), 1)
        real_input = torch.cat((real_images, conditioned_images), 1)
        fake_logits = self.critic(fake_input)
        real_logits = self.critic(real_input)

        # Compute the loss for the critic
        loss_C = real_logits.mean() - fake_logits.mean()

        # Compute the gradient penalty
        alpha = torch.rand(real_images.size(0), 1, 1, 1, requires_grad=True).to(real_images.device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images.detach()).requires_grad_(True)
        interpolated_logits = self.critic(interpolated, conditioned_images)
    
        gradients = torch.autograd.grad(outputs=interpolated_logits, inputs=interpolated,
                                    grad_outputs=torch.ones_like(interpolated_logits), create_graph=True, retain_graph=True)[0]
        gradients = gradients.view(len(gradients), -1)
        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        loss_C += self.lambda_gp * gradients_penalty
    
        # Compute the R1 regularization loss
        r1_reg = gradients.pow(2).sum(1).mean()
        loss_C += self.lambda_r1 * r1_reg

        # Backpropagation
        loss_C.backward()
        optimizer_C.step()
        self.critic_losses.append(loss_C.item())

    
    def training_step(self, batch, batch_idx):
        real, condition = batch
        optimizer_C, optimizer_G = self.optimizers()  # Access optimizers
    
        # Debugging shapes
        print(f"Real images shape: {real.shape}")
        print(f"Conditioned images shape: {condition.shape}")
    
        # Update the critic
        self.critic_step(real, condition, optimizer_C)
    
        # Update the generator
        self.generator_step(real, condition, optimizer_G)
    
    
        # Logging and saving models
        gen_mean = sum(self.generator_losses[-self.display_step:]) / self.display_step
        crit_mean = sum(self.critic_losses[-self.display_step:]) / self.display_step
        if self.current_epoch % self.display_step == 0 and batch_idx == 0:
            fake = self.generator(condition).detach()
            print(f"Epoch {self.current_epoch}: Generator loss: {gen_mean}, Critic loss: {crit_mean}")
            display_progress(condition[0], real[0], fake[0], self.current_epoch)
    
        # Save models every 10 epochs
        if (self.current_epoch + 1) % 10 == 0 and batch_idx == self.trainer.num_training_batches - 1:
            torch.save(self.generator.state_dict(), f"/kaggle/working/cwgan_generator_epoch_{self.current_epoch+1}.pt")
            torch.save(self.critic.state_dict(), f"/kaggle/working/cwgan_critic_epoch_{self.current_epoch+1}.pt")
            print(f"Saved models at epoch {self.current_epoch+1}")

        # Final save at epoch 150
        if self.current_epoch == 149 and batch_idx == self.trainer.num_training_batches - 1:
            torch.save(self.generator.state_dict(), "cwgan_generator_final.pt")
            torch.save(self.critic.state_dict(), "cwgan_critic_final.pt")
            print("Saved final models at epoch 150")
    

In [14]:
from torch import nn, optim

gc.collect()
cwgan = CWGAN(in_channels = 1, out_channels = 2 ,learning_rate=2e-4, lambda_recon=100, display_step=10)

In [15]:
trainer = pl.Trainer(max_epochs=150)
trainer.fit(cwgan, train_loader1)

[2024-08-25 00:04:53,828: INFO: rank_zero: GPU available: True (cuda), used: True]
[2024-08-25 00:04:53,829: INFO: rank_zero: TPU available: False, using: 0 TPU cores]
[2024-08-25 00:04:53,830: INFO: rank_zero: IPU available: False, using: 0 IPUs]
[2024-08-25 00:04:53,831: INFO: rank_zero: HPU available: False, using: 0 HPUs]


[2024-08-25 00:04:54,549: INFO: cuda: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]]
[2024-08-25 00:04:54,552: INFO: model_summary: 
  | Name            | Type      | Params
----------------------------------------------
0 | generator       | Generator | 8.2 M 
1 | critic          | Critic    | 2.8 M 
2 | recon_criterion | L1Loss    | 0     
----------------------------------------------
11.0 M    Trainable params
0         Non-trainable params
11.0 M    Total params
43.893    Total estimated model params size (MB)]


Training: |          | 0/? [00:00<?, ?it/s]

Real images shape: torch.Size([1, 2, 224, 224])
Conditioned images shape: torch.Size([1, 1, 224, 224])


TypeError: Critic.forward() missing 1 required positional argument: 'l'