PushkarA07 committed on
Commit
61ffd48
·
1 Parent(s): b1c39fd

initial commit

Browse files
Files changed (3) hide show
  1. main-model.pt +3 -0
  2. res18-unet.pt +3 -0
  3. web_app.py +417 -0
main-model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9117325fed27284a7e1fcd27b01ad3d6840ec6f14fe792c1004e8329eca264ea
3
+ size 135587239
res18-unet.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:245439ac4c6bacd91752800f98a425922fb4fc73fd107a3772ae6dab0e2ea3ca
3
+ size 124507223
web_app.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.vision.models.unet import DynamicUnet
2
+ from torchvision.models.resnet import resnet18
3
+ from fastai.vision.learner import create_body
4
+ import streamlit as st
5
+ from PIL import Image
6
+ import cv2 as cv
7
+
8
+ # ---------Backend--------------------------------------------------------------
9
+
10
+ import os
11
+ import glob
12
+ import time
13
+ import numpy as np
14
+ from PIL import Image
15
+ from pathlib import Path
16
+ from tqdm.notebook import tqdm
17
+ import matplotlib.pyplot as plt
18
+ from skimage.color import rgb2lab, lab2rgb
19
+
20
+ # pip install fastai==2.4
21
+
22
+ import torch
23
+ from torch import nn, optim
24
+ from torchvision import transforms
25
+ from torchvision.utils import make_grid
26
+ from torch.utils.data import Dataset, DataLoader
27
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_colab = None

# Side length, in pixels, that the dataset classes resize every image to.
SIZE = 256
31
+
32
+
33
class ColorizationDataset(Dataset):
    """Dataset of image files served as scaled L*a*b* tensors.

    Each item is a dict holding the lightness channel under ``'L'`` and the
    two colour channels under ``'ab'``.

    Args:
        paths: Indexable collection of image file paths.
        split: ``'train'`` (resize plus random horizontal flip) or ``'val'``
            (resize only).

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        # Fail fast: previously an unknown split left ``self.transforms``
        # undefined and only surfaced as an AttributeError in __getitem__.
        if split not in ('train', 'val'):
            raise ValueError(
                f"split must be 'train' or 'val', got {split!r}")

        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), Image.BICUBIC),
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        else:
            self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``{'L': ..., 'ab': ...}``."""
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        # Channel 0 is L, channels 1-2 are a and b; both are rescaled so the
        # values are intended to lie roughly between -1 and 1.
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.

        return {'L': L, 'ab': ab}

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)
60
+
61
+
62
+ # A handy function to make our dataloaders
63
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    """Build a ``DataLoader`` over a ``ColorizationDataset``.

    ``batch_size``, ``n_workers`` and ``pin_memory`` configure the loader;
    every other keyword argument is forwarded unchanged to the
    ``ColorizationDataset`` constructor.
    """
    return DataLoader(
        ColorizationDataset(**kwargs),
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )
68
+
69
+
70
class UnetBlock(nn.Module):
    """One down/up level of the U-Net, optionally wrapping a nested level.

    Args:
        nf: Channels produced by the up-convolution of this block.
        ni: Channels produced by the down-convolution of this block.
        submodule: Nested block placed between the down and up paths
            (not used when ``innermost`` is set).
        input_c: Channels of the incoming tensor; defaults to ``nf``.
        dropout: Append ``Dropout(0.5)`` to the up path (intermediate
            blocks only).
        innermost: Build the bottleneck block, which has no submodule.
        outermost: Build the top block, which ends in ``Tanh`` and does not
            concatenate a skip connection. Takes precedence over
            ``innermost``.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c

        # Down-convolution first, then up-convolution, so parameter
        # initialisation consumes the RNG in a fixed order.
        down_conv = nn.Conv2d(in_channels, ni, kernel_size=4,
                              stride=2, padding=1, bias=False)
        # A nested block returns its input concatenated with its output, so
        # every block except the innermost sees ni * 2 channels on the way up.
        if outermost or not innermost:
            up_in = ni * 2
        else:
            up_in = ni
        # Only the outermost up-convolution keeps its bias term.
        up_conv = nn.ConvTranspose2d(up_in, nf, kernel_size=4,
                                     stride=2, padding=1, bias=bool(outermost))

        if outermost:
            layers = [down_conv, submodule, nn.ReLU(True), up_conv, nn.Tanh()]
        elif innermost:
            layers = [nn.LeakyReLU(0.2, True), down_conv,
                      nn.ReLU(True), up_conv, nn.BatchNorm2d(nf)]
        else:
            layers = [nn.LeakyReLU(0.2, True), down_conv, nn.BatchNorm2d(ni),
                      submodule,
                      nn.ReLU(True), up_conv, nn.BatchNorm2d(nf)]
            if dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; non-outermost blocks prepend ``x`` as a skip."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)
111
+
112
+
113
class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` levels.

    Args:
        input_c: Channels of the input tensor.
        output_c: Channels of the output tensor.
        n_down: Total number of down-sampling levels.
        num_filters: Channel width of the outermost level; the deepest
            levels use ``num_filters * 8``.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        # Start at the bottleneck and wrap outwards.
        block = UnetBlock(widest, widest, innermost=True)

        # n_down - 5 full-width levels, each with dropout.
        for _level in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        # Three levels that halve the channel count each time.
        outer_width = widest
        for _level in range(3):
            inner_width, outer_width = outer_width, outer_width // 2
            block = UnetBlock(outer_width, inner_width, submodule=block)

        self.model = UnetBlock(output_c, outer_width, input_c=input_c,
                               submodule=block, outermost=True)

    def forward(self, x):
        """Apply the nested blocks to ``x``."""
        return self.model(x)
131
+
132
+
133
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator built from conv / norm / activation stacks.

    Args:
        input_c: Channels of the input tensor.
        num_filters: Output channels of the first stack; each later stack
            doubles the width.
        n_down: Number of width-doubling stacks after the first one.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # First stack: no normalisation.
        stacks = [self.get_layers(input_c, num_filters, norm=False)]

        # Width-doubling stacks; the last of them uses stride 1 instead of 2.
        for i in range(n_down):
            stride = 1 if i == (n_down - 1) else 2
            stacks.append(self.get_layers(num_filters * 2 ** i,
                                          num_filters * 2 ** (i + 1),
                                          s=stride))

        # Output stack: a bare convolution down to one channel, with neither
        # normalisation nor activation.
        stacks.append(self.get_layers(num_filters * 2 ** n_down, 1,
                                      s=1, norm=False, act=False))
        self.model = nn.Sequential(*stacks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return ``Conv2d`` [+ ``BatchNorm2d``] [+ ``LeakyReLU(0.2)``].

        The convolution carries a bias only when ``norm`` is false.
        """
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        extras = []
        if norm:
            extras.append(nn.BatchNorm2d(nf))
        if act:
            extras.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(conv, *extras)

    def forward(self, x):
        """Apply all stacks to ``x``."""
        return self.model(x)
158
+
159
+
160
class GANLoss(nn.Module):
    """Adversarial loss that builds its own real/fake target tensors.

    Args:
        gan_mode: ``'vanilla'`` (``BCEWithLogitsLoss``) or ``'lsgan'``
            (``MSELoss``).
        real_label: Target value used for "real" predictions.
        fake_label: Target value used for "fake" predictions.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers so the labels follow the module across devices and are
        # part of its state_dict.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Previously an unknown mode left ``self.loss`` undefined and
            # only failed on the first call.
            raise ValueError(
                f"gan_mode must be 'vanilla' or 'lsgan', got {gan_mode!r}")

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label expanded to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the chosen label.

        Defined as ``forward`` rather than ``__call__`` so that
        ``nn.Module.__call__`` (and therefore module hooks) keeps working;
        ``criterion(preds, flag)`` call sites are unaffected.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
181
+
182
+
183
def init_weights(net, init='norm', gain=0.02):
    """Re-initialise the convolution and ``BatchNorm2d`` layers of ``net``.

    Modules whose class name contains ``Conv`` get their weights drawn
    according to ``init`` (``'norm'``, ``'xavier'`` or ``'kaiming'``; any
    other value leaves the weights untouched) and their bias, if any, zeroed.
    ``BatchNorm2d`` modules get weights drawn from N(1, gain) and a zero bias.

    Returns:
        The same ``net`` object, modified in place.
    """

    def _reinit(module):
        kind = module.__class__.__name__
        is_conv = hasattr(module, 'weight') and 'Conv' in kind

        if not is_conv:
            if 'BatchNorm2d' in kind:
                nn.init.normal_(module.weight.data, 1., gain)
                nn.init.constant_(module.bias.data, 0.)
            return

        weights = module.weight.data
        if init == 'norm':
            nn.init.normal_(weights, mean=0.0, std=gain)
        elif init == 'xavier':
            nn.init.xavier_normal_(weights, gain=gain)
        elif init == 'kaiming':
            nn.init.kaiming_normal_(weights, a=0, mode='fan_in')

        bias = getattr(module, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)

    net.apply(_reinit)
    print(f"model initialized with {init} initialization")
    return net
204
+
205
+
206
def init_model(model, device):
    """Move ``model`` to ``device`` and re-initialise it via ``init_weights``.

    ``init_weights`` is called with its default arguments; the initialised
    model is returned.
    """
    return init_weights(model.to(device))
210
+
211
+
212
class MainModel(nn.Module):
    """Generator/discriminator pair with their losses and optimisers.

    The generator maps the ``L`` tensor to ``fake_color``; the discriminator
    scores ``L`` concatenated with either the generated or the real ``ab``
    tensor along the channel dimension.

    Args:
        net_G: Generator to use. When ``None`` a ``Unet`` is built and
            initialised with ``init_model``.
        lr_G: Adam learning rate for the generator.
        lr_D: Adam learning rate for the discriminator.
        beta1: First Adam beta, shared by both optimisers.
        beta2: Second Adam beta, shared by both optimisers.
        lambda_L1: Weight of the L1 term in the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        # Generator: take the caller's network as-is, or build a default one.
        if net_G is not None:
            self.net_G = net_G.to(self.device)
        else:
            default_generator = Unet(
                input_c=1, output_c=2, n_down=8, num_filters=64)
            self.net_G = init_model(default_generator, self.device)

        # Discriminator input is L (1 channel) + ab (2 channels).
        discriminator = PatchDiscriminator(
            input_c=3, n_down=3, num_filters=64)
        self.net_D = init_model(discriminator, self.device)

        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()

        adam_betas = (beta1, beta2)
        self.opt_G = optim.Adam(
            self.net_G.parameters(), lr=lr_G, betas=adam_betas)
        self.opt_D = optim.Adam(
            self.net_D.parameters(), lr=lr_D, betas=adam_betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def setup_input(self, data):
        """Store ``data['L']`` and ``data['ab']`` on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Run the generator on ``self.L`` and keep the result as ``fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss and back-propagate it."""
        # Detach so no gradient flows back into the generator here.
        fake_pair = torch.cat([self.L, self.fake_color], dim=1).detach()
        self.loss_D_fake = self.GANcriterion(self.net_D(fake_pair), False)

        real_pair = torch.cat([self.L, self.ab], dim=1)
        self.loss_D_real = self.GANcriterion(self.net_D(real_pair), True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and back-propagate it."""
        fake_pair = torch.cat([self.L, self.fake_color], dim=1)
        # The generator is rewarded when the discriminator calls its output real.
        self.loss_G_GAN = self.GANcriterion(self.net_D(fake_pair), True)
        self.loss_G_L1 = self.lambda_L1 * self.L1criterion(
            self.fake_color, self.ab)
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step: discriminator update, then generator update."""
        self.forward()

        # Discriminator step.
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Generator step; the discriminator is frozen while it runs.
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
278
+
279
+
280
class AverageMeter:
    """Running weighted average: tracks ``count``, ``sum`` and ``avg``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set all three statistics back to ``0.``."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count
291
+
292
+
293
def create_loss_meters():
    """Return a dict with one fresh ``AverageMeter`` per tracked loss name.

    The keys match the ``loss_*`` attribute names set by ``MainModel``.
    """
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}
307
+
308
+
309
def update_losses(model, loss_meter_dict, count):
    """Feed each meter the current value of the same-named loss on ``model``.

    For every ``name -> meter`` pair, reads ``getattr(model, name).item()``
    and passes it to ``meter.update`` with the given ``count``.
    """
    for name, meter in loss_meter_dict.items():
        current = getattr(model, name).item()
        meter.update(current, count=count)
313
+
314
+
315
def lab_to_rgb(L, ab):
    """Convert a batch of scaled L and ab tensors to RGB arrays.

    Undoes the dataset scaling (``L = L*/50 - 1``, ``ab = ab*/110``),
    moves channels last and runs ``lab2rgb`` on every image.

    Returns:
        A NumPy array stacking the converted images along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    # Channels-first tensor -> channels-last NumPy array for lab2rgb.
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
328
+
329
+
330
def visualize(model, data, dims):
    """Colourise one batch and show its first image in the Streamlit page.

    Args:
        model: Object exposing ``net_G``, ``setup_input(data)`` and
            ``forward()``, and providing ``L`` and ``fake_color`` afterwards.
        data: Batch passed straight to ``model.setup_input``.
        dims: ``(height, width)`` the output image is resized to.
    """
    # Generate in eval mode without tracking gradients, then put the
    # generator back into train mode as before.
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()

    fake_imgs = lab_to_rgb(model.L, model.fake_color.detach())

    # Only the first image of the batch is shown.
    # cv.resize expects dsize as (width, height), hence the swapped indices.
    img = cv.resize(fake_imgs[0], dsize=(dims[1], dims[0]),
                    interpolation=cv.INTER_CUBIC)
    # st.text(f"Size of fake image {fake_imgs[i].shape} \n Type of image = {type(fake_imgs[i])}")
    st.image(img, caption="Output image",
             use_column_width='auto', clamp=True)
349
+
350
+
351
def log_results(loss_meter_dict):
    """Print ``name: avg`` (five decimals) for every meter in the dict."""
    for loss_name in loss_meter_dict:
        average = loss_meter_dict[loss_name].avg
        print(f"{loss_name}: {average:.5f}")
354
+
355
+
356
+ # pip install fastai==2.4
357
+
358
+
359
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a fastai ``DynamicUnet`` on top of a ResNet-18 body.

    Args:
        n_input: Input channels of the network.
        n_output: Output channels of the network.
        size: Side length of the square images the U-Net is built for.

    Returns:
        The network, moved to CUDA when available and to the CPU otherwise.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): an instantiated resnet18 is handed to create_body here;
    # confirm this matches the create_body signature of the pinned fastai.
    backbone = resnet18()
    encoder = create_body(backbone, pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target)
364
+
365
+
366
# Build the generator and restore its weights from the bundled checkpoint.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(
    torch.load("res18-unet.pt", map_location=device)
)

# Wrap the generator in the full model and restore the complete checkpoint.
model = MainModel(net_G=net_G)
model.load_state_dict(
    torch.load("main-model.pt", map_location=device)
)
370
+
371
+
372
class MyDataset(torch.utils.data.Dataset):
    """Dataset over in-memory images, served as scaled L*a*b* tensors.

    Mirrors the ``'val'`` behaviour of ``ColorizationDataset`` but takes
    already-opened images instead of file paths.

    Args:
        img_list: Indexable collection of images (PIL images in this app).
    """

    def __init__(self, img_list):
        super().__init__()
        self.img_list = img_list
        self.augmentations = transforms.Resize((SIZE, SIZE), Image.BICUBIC)

    def __len__(self):
        """Return the number of images."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``{'L': ..., 'ab': ...}`` for image ``idx``."""
        img = self.img_list[idx]
        # rgb2lab needs exactly three channels. Uploaded files are often
        # grayscale or RGBA, so normalise PIL images to RGB first, as
        # ColorizationDataset does when it opens a file.
        if isinstance(img, Image.Image):
            img = img.convert("RGB")
        img = self.augmentations(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        # Channel 0 is L, channels 1-2 are a and b; both are rescaled so the
        # values are intended to lie roughly between -1 and 1.
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}
390
+
391
+
392
+ # A handy function to make our dataloaders
393
def make_dataloaders2(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    """Build a ``DataLoader`` over a ``MyDataset``.

    ``batch_size``, ``n_workers`` and ``pin_memory`` configure the loader;
    every other keyword argument is forwarded unchanged to the ``MyDataset``
    constructor.
    """
    return DataLoader(
        MyDataset(**kwargs),
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )
398
+
399
+
400
+ # st.set_option('deprecation.showfileUploaderEncoding', False)
401
+ # @st.cache(allow_output_mutation= True)
402
# ---------Frontend: upload an image, colourise it, show the result-------------
st.write("""
# Image Recolorisation
"""
         )
file_up = st.file_uploader("Upload an jpg image", type=["jpg", "jpeg", "png"])

if file_up is not None:
    # Normalise to three channels so grayscale / RGBA uploads survive the
    # RGB -> L*a*b conversion done by the dataset.
    im = Image.open(file_up).convert("RGB")
    # PIL images have no ``.shape``; build (height, width) explicitly, which
    # is the order ``visualize`` expects for ``dims``.
    a = (im.height, im.width)
    st.text(body=f"Size of uploaded image {a}")
    st.image(im, caption="Uploaded Image.", use_column_width='auto')
    # One in-memory image: load it in the main process instead of spawning
    # DataLoader worker processes from inside the Streamlit script.
    test_dl = make_dataloaders2(img_list=[im], n_workers=0)
    for data in test_dl:
        # Inference only. ``visualize`` sets up the input and runs the
        # generator itself; calling ``model.optimize()`` here would run a
        # training step on every uploaded image.
        visualize(model, data, a)