Ved Gupta committed on
Commit
77b60c4
·
1 Parent(s): 5ec91c1

initial commit

Browse files
Pipfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+
8
+ [dev-packages]
9
+
10
+ [requires]
11
+ python_version = "3.10"
image-colorization-using-gan-main.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+
4
+ import os
5
+ import sys
6
+ import glob
7
+ import time
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ from tqdm.notebook import tqdm
12
+ import matplotlib.pyplot as plt
13
+ from skimage.color import rgb2lab, lab2rgb
14
+
15
+ import torch
16
+ from torch import nn, optim
17
+ from torchvision import transforms
18
+ from torchvision.utils import make_grid
19
+ from torch.utils.data import Dataset, DataLoader
20
+
21
+
22
+ from utility import *
23
+ from model import *
24
+
25
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = "model/ImageColorizationModel.pth"

# Fetch the checkpoint on first run, then ALWAYS load it. Previously the
# download branch never loaded the weights, leaving `model` as None and
# making every prediction fail on `model.net_G`.
if not os.path.exists(model_path):
    print("Model not find")
    download_from_drive()
    # Check the file itself rather than trusting the helper's return value,
    # so a failed download stops here with a clear error.
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model checkpoint could not be downloaded to {model_path}"
        )
    print("Model Downloaded")

model = load_model(model_class=MainModel, file_path=model_path)
print("Model Loaded")
38
+
39
def predict_and_return_image(image):
    """Colorize one image with the module-level ``model``.

    Args:
        image: Input forwarded unchanged to ``create_lab_tensors``.

    Returns:
        The first element of the batch produced by ``lab_to_rgb`` for the
        generator's predicted colour channels.
    """
    lab_inputs = create_lab_tensors(image)

    # Inference only: generator in eval mode, no gradient tracking.
    generator = model.net_G
    generator.eval()
    with torch.no_grad():
        model.setup_input(lab_inputs)
        model.forward()

    predicted_ab = model.fake_color.detach()
    colorized_batch = lab_to_rgb(model.L, predicted_ab)
    return colorized_batch[0]
49
+
50
+
51
+
52
+
53
+
54
# `gr` is used below but was never imported in this script, which would raise
# NameError as soon as the interface is built.
import gradio as gr

title = "Black&White to Color image"
description = "Transforming Black & White Image in to colored image. Upload a black and white image to see it colorized by our deep learning model."

gr.Interface(
    fn=predict_and_return_image,
    title=title,
    description=description,
    # Ask Gradio for a PIL image: the preprocessing helper calls
    # `.convert('RGB')` on non-string inputs, which numpy arrays do not have.
    inputs=[gr.Image(label="Gray Scale Image", type="pil")],
    outputs=[
        gr.Image(label="Predicted Colored Image")
    ],
).launch(share=True, debug=True)
model/Discriminator.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import time
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from tqdm.notebook import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from skimage.color import rgb2lab, lab2rgb
10
+
11
+ import torch
12
+ from torch import nn, optim
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
class PatchDiscriminator(nn.Module):
    """Stack of strided 4x4 convolution blocks ending in a 1-channel map.

    Layout (in order inside ``self.model``):
      * one conv + LeakyReLU block without normalization,
      * ``n_down`` conv + BatchNorm + LeakyReLU blocks that double the channel
        count each time; all use stride 2 except the last, which uses stride 1,
      * one final stride-1 conv to a single channel with neither
        normalization nor activation.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for depth in range(n_down):
            in_ch = num_filters * 2 ** depth
            out_ch = num_filters * 2 ** (depth + 1)
            # The last block of this loop keeps the spatial size (stride 1).
            stride = 1 if depth == (n_down - 1) else 2
            blocks.append(self.get_layers(in_ch, out_ch, s=stride))
        # Output head: raw scores, so no normalization and no activation.
        blocks.append(
            self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)
        )
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one conv block: Conv2d, optional BatchNorm2d, optional LeakyReLU.

        The convolution only gets a bias when no BatchNorm follows it.
        """
        stack = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            stack.append(nn.BatchNorm2d(nf))
        if act:
            stack.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the full block stack."""
        return self.model(x)
model/Generator.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import time
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from tqdm.notebook import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from skimage.color import rgb2lab, lab2rgb
10
+
11
+ import torch
12
+ from torch import nn, optim
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
class UnetBlock(nn.Module):
    """One down/up level of the U-Net, optionally wrapping a deeper level.

    Args:
        nf: Channels produced by this level's transposed convolution (and, by
            default, the channels it receives).
        ni: Channels produced by this level's down convolution.
        submodule: The deeper ``UnetBlock`` placed between down and up paths
            (not used by the innermost level).
        input_c: Input channels; defaults to ``nf`` when ``None``.
        dropout: Append ``Dropout(0.5)`` to the up path (intermediate levels only).
        innermost: Build the bottleneck level (no submodule, no down norm).
        outermost: Build the top level (no norms, ``Tanh`` output, no skip).
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c
        # The down conv is always created before the up conv, keeping the
        # order of random weight initialisation stable.
        encoder_conv = nn.Conv2d(in_channels, ni, kernel_size=4,
                                 stride=2, padding=1, bias=False)

        if outermost:
            decoder_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                              stride=2, padding=1)
            stages = [encoder_conv, submodule,
                      nn.ReLU(True), decoder_conv, nn.Tanh()]
        elif innermost:
            decoder_conv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                              stride=2, padding=1, bias=False)
            stages = [nn.LeakyReLU(0.2, True), encoder_conv,
                      nn.ReLU(True), decoder_conv, nn.BatchNorm2d(nf)]
        else:
            # The submodule returns its input concatenated with its output,
            # hence ni * 2 input channels on the way up.
            decoder_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                              stride=2, padding=1, bias=False)
            stages = [nn.LeakyReLU(0.2, True), encoder_conv, nn.BatchNorm2d(ni),
                      submodule,
                      nn.ReLU(True), decoder_conv, nn.BatchNorm2d(nf)]
            if dropout:
                stages.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the block output; non-outermost levels prepend ``x`` as a skip."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)
58
+
59
+
60
class Unet(nn.Module):
    """U-Net assembled inside-out from nested ``UnetBlock`` levels.

    Args:
        input_c: Channels of the input tensor.
        output_c: Channels of the output tensor.
        n_down: Total number of down-sampling levels; ``n_down - 5`` of them
            are full-width intermediate levels with dropout.
        num_filters: Base channel width; the deepest levels use 8x this.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        # Bottleneck first, then wrap it level by level.
        core = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            core = UnetBlock(widest, widest, submodule=core, dropout=True)

        # Three levels that halve the channel width on the way out.
        for inner_width in (widest, widest // 2, widest // 4):
            core = UnetBlock(inner_width // 2, inner_width, submodule=core)

        outer_width = widest // 8
        self.model = UnetBlock(output_c, outer_width, input_c=input_c,
                               submodule=core, outermost=True)

    def forward(self, x):
        """Run ``x`` through the nested blocks."""
        return self.model(x)
model/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import time
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from tqdm.notebook import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from skimage.color import rgb2lab, lab2rgb
10
+
11
+ import torch
12
+ from torch import nn, optim
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ from .Generator import UnetBlock , Unet
18
+ from .Discriminator import PatchDiscriminator
19
+ from .weights import init_weights
20
+
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
def init_model(model, device):
    """Move ``model`` to ``device``, run ``init_weights`` on it and return it."""
    return init_weights(model.to(device))
27
+
28
class MainModel(nn.Module):
    """Generator/discriminator pair with their losses and Adam optimizers.

    Args:
        net_G: Optional ready-made generator; when ``None`` a ``Unet`` is
            created and weight-initialised.
        lr_G: Learning rate of the generator optimizer.
        lr_D: Learning rate of the discriminator optimizer.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
        lambda_L1: Weight of the L1 term in the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is not None:
            self.net_G = net_G.to(self.device)
        else:
            generator = Unet(input_c=1, output_c=2, n_down=8, num_filters=64)
            self.net_G = init_model(generator, self.device)
        discriminator = PatchDiscriminator(input_c=3, n_down=3, num_filters=64)
        self.net_D = init_model(discriminator, self.device)

        # NOTE(review): GANLoss is neither defined nor imported in the part of
        # this module that is visible here - confirm it is provided elsewhere,
        # otherwise constructing MainModel raises NameError.
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()

        adam_betas = (beta1, beta2)
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=adam_betas)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=adam_betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def setup_input(self, data):
        """Store ``data['L']`` and ``data['ab']`` on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Run the generator on the stored ``L`` and keep the result in ``fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backpropagate it."""
        generated_pair = torch.cat([self.L, self.fake_color], dim=1)
        # Detach so this backward pass does not reach the generator.
        scores_on_fake = self.net_D(generated_pair.detach())
        self.loss_D_fake = self.GANcriterion(scores_on_fake, False)

        target_pair = torch.cat([self.L, self.ab], dim=1)
        scores_on_real = self.net_D(target_pair)
        self.loss_D_real = self.GANcriterion(scores_on_real, True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (adversarial + weighted L1) and backpropagate it."""
        generated_pair = torch.cat([self.L, self.fake_color], dim=1)
        scores_on_fake = self.net_D(generated_pair)
        self.loss_G_GAN = self.GANcriterion(scores_on_fake, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step: update the discriminator, then the generator."""
        self.forward()

        # Discriminator step.
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Generator step; the discriminator's parameters are frozen meanwhile.
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
model/weights.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import time
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from tqdm.notebook import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from skimage.color import rgb2lab, lab2rgb
10
+
11
+ import torch
12
+ from torch import nn, optim
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
def init_weights(net, init='norm', gain=0.02):
    """Re-initialise the conv and BatchNorm2d layers of ``net`` in place.

    Args:
        net: Module whose sub-modules are visited with ``net.apply``.
        init: ``'norm'``, ``'xavier'`` or ``'kaiming'``; any other value
            leaves conv weights untouched (conv biases are still zeroed).
        gain: Std for ``'norm'`` and for BatchNorm weights, gain for ``'xavier'``.

    Returns:
        The same ``net`` object.
    """
    conv_initializers = {
        'norm': lambda w: nn.init.normal_(w, mean=0.0, std=gain),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
    }

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            fill_weight = conv_initializers.get(init)
            if fill_weight is not None:
                fill_weight(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm scale is drawn around 1, shift is zeroed.
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
utility/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .helper import *
utility/helper.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import time
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from tqdm.notebook import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from skimage.color import rgb2lab, lab2rgb
10
+
11
+ import torch
12
+ from torch import nn, optim
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ import requests
20
+ import gdown
21
+
22
def download_from_drive():
    """Download the model checkpoint from Google Drive with ``gdown``.

    The file is written to ``model/ImageColorizationModel.pth``.

    Returns:
        bool: ``True`` when ``gdown.download`` reported an output path,
        ``False`` when it raised or returned ``None``.
    """
    url = "https://drive.google.com/uc?id=1EhuMET76c02VFyRW8Pie7BwNCDHmQiad"
    output = "model/ImageColorizationModel.pth"
    try:
        saved_path = gdown.download(url, output, quiet=False)
    except Exception as err:
        # `except Exception` (not a bare `except`) so KeyboardInterrupt and
        # SystemExit still propagate; the cause is printed, not hidden.
        print(f"Error Occured in Downloading model from Gdrive: {err}")
        return False
    # NOTE(review): gdown.download is expected to return the output path on
    # success and None when the file cannot be retrieved - confirm for the
    # pinned gdown version.
    return saved_path is not None
31
+
32
+
33
class AverageMeter:
    """Running weighted average: tracks ``count``, ``sum`` and ``avg``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all three statistics (stored as floats)."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` weighted by ``count`` and refresh the average."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
44
+
45
def create_loss_meters():
    """Return a dict with one fresh ``AverageMeter`` per tracked loss name."""
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}
59
+
60
def update_losses(model, loss_meter_dict, count):
    """Feed each meter with the model attribute of the same name.

    For every key in ``loss_meter_dict`` the attribute ``model.<key>`` is read,
    converted with ``.item()`` and passed to the meter's ``update`` together
    with ``count``.
    """
    for name in loss_meter_dict:
        current_value = getattr(model, name).item()
        loss_meter_dict[name].update(current_value, count=count)
64
+
65
def lab_to_rgb(L, ab):
    """Convert a batch of normalised L and ab tensors to RGB arrays.

    ``L`` is mapped back with ``(L + 1) * 50`` and ``ab`` with ``ab * 110``;
    the two are concatenated on dim 1, moved to channels-last numpy, and each
    image is converted with ``lab2rgb``.

    Returns:
        The per-image RGB arrays stacked along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
78
+
79
def visualize(model, data, save=True):
    """Plot input L, predicted colours and ground truth for up to 5 samples.

    The generator is switched to eval mode for the forward pass and back to
    train mode afterwards.

    Args:
        model: Object exposing ``net_G``, ``setup_input``, ``forward``,
            ``fake_color``, ``ab`` and ``L``.
        data: Batch passed to ``model.setup_input``.
        save: When true, the figure is written to
            ``colorization_<timestamp>.png`` in the working directory.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    n_cols = 5
    # Batches smaller than 5 used to raise IndexError; show what is available.
    n_shown = min(n_cols, len(fake_imgs))
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_shown):
        # Row 1: grayscale input, row 2: prediction, row 3: ground truth.
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
104
+
105
def log_results(loss_meter_dict):
    """Print ``<name>: <avg>`` (5 decimals) for every meter in the dict."""
    for loss_name in loss_meter_dict:
        loss_meter = loss_meter_dict[loss_name]
        print(f"{loss_name}: {loss_meter.avg:.5f}")
108
+
109
def create_lab_tensors(image, size=256):
    """Turn an image into normalised L and ab tensors for the generator.

    Args:
        image: A path (``str`` or ``os.PathLike``), a numpy array, or any
            object with a PIL-style ``convert('RGB')`` method.
        size: Side length the image is resized to (square). The module never
            defined the ``SIZE`` constant the original code referenced, so it
            is now a parameter. NOTE(review): 256 is assumed to match the
            resolution the checkpoint was trained at - confirm.

    Returns:
        dict: ``{'L': L, 'ab': ab}`` where ``L`` is the lightness channel
        scaled with ``/ 50 - 1`` and given a leading batch dimension, and
        ``ab`` holds the two colour channels scaled with ``/ 110``.
        NOTE(review): ``ab`` has no batch dimension while ``L`` does; kept
        as-is so existing callers see the same shapes.
    """
    if isinstance(image, (str, os.PathLike)):
        img = Image.open(image).convert('RGB')
    elif isinstance(image, np.ndarray):
        # e.g. arrays handed over by a UI layer instead of PIL images.
        img = Image.fromarray(image).convert('RGB')
    else:
        img = image.convert('RGB')

    # Resize only. The previous RandomHorizontalFlip was training-time
    # augmentation and mirrored inference outputs at random.
    img = transforms.Resize((size, size), Image.BICUBIC)(img)
    img_lab = rgb2lab(np.array(img)).astype("float32")  # RGB -> L*a*b
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50. - 1.
    L = L.unsqueeze(0)
    ab = img_lab[[1, 2], ...] / 110.
    return {'L': L, 'ab': ab}
135
+
136
+
137
def predict_and_visualize_single_image(model, data, save=True):
    """Colorize one prepared sample and show it next to its grayscale input.

    Args:
        model: Object exposing ``net_G``, ``setup_input``, ``forward``,
            ``fake_color`` and ``L``.
        data: Input passed to ``model.setup_input``.
        save: When true, the figure is written to
            ``colorization_<timestamp>.png`` in the working directory.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()

    predicted_ab = model.fake_color.detach()
    L = model.L
    colorized = lab_to_rgb(L, predicted_ab)

    fig, (gray_ax, color_ax) = plt.subplots(1, 2, figsize=(8, 4))

    gray_ax.imshow(L[0][0].cpu(), cmap='gray')
    gray_ax.set_title("Grey Image")
    gray_ax.axis('off')

    color_ax.imshow(colorized[0])
    color_ax.set_title("Colored Image")
    color_ax.axis('off')

    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
156
+
157
def predict_color(model , image , save=False):
    """Colorize ``image`` with ``model`` and display the result.

    The image is prepared with ``create_lab_tensors`` and handed to
    ``predict_and_visualize_single_image``.

    Args:
        model: Colorization model forwarded to the visualisation helper.
        image: Either a path to the image file or a direct image input.
        save: Forwarded to the visualisation helper.
    """
    prepared = create_lab_tensors(image)
    predict_and_visualize_single_image(model, prepared, save)
166
+
167
+
168
def load_model(model_class, file_path):
    """Instantiate ``model_class`` and load a saved state dict into it.

    Args:
        model_class: Callable returning a ``torch.nn.Module`` when called
            with no arguments.
        file_path: Path of the file previously written with ``torch.save``.

    Returns:
        The model instance with the loaded state dict.
    """
    model = model_class()
    # Remap stored tensors to whatever device is available here; without
    # map_location a checkpoint saved on GPU cannot be loaded on a CPU-only
    # machine.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(file_path, map_location=target_device)
    model.load_state_dict(state_dict)
    return model