PBJ commited on
Commit
3776899
1 Parent(s): bec7243

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +17 -0
  2. web_app.py +367 -0
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ PIL
3
+ cv2
4
+ os
5
+ glob
6
+ time
7
+ numpy
8
+ pathlib
9
+ tqdm.notebook
10
+ matplotlib.pyplot
11
+ skimage.color
12
+ torch
13
+ torchvision
14
+ torchvision.utils
15
+ torch.utils.data
16
+ torchinfo
17
+ fastai==2.4
web_app.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import cv2 as cv
4
+
5
+ # ---------Backend--------------------------------------------------------------
6
+
7
+ import os
8
+ import glob
9
+ import time
10
+ import numpy as np
11
+ from PIL import Image
12
+ from pathlib import Path
13
+ from tqdm.notebook import tqdm
14
+ import matplotlib.pyplot as plt
15
+ from skimage.color import rgb2lab, lab2rgb
16
+
17
+ # pip install fastai==2.4
18
+
19
+ import torch
20
+ from torch import nn, optim
21
+ from torchvision import transforms
22
+ from torchvision.utils import make_grid
23
+ from torch.utils.data import Dataset, DataLoader
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ use_colab = None
26
+
27
+ SIZE = 256
28
+ class ColorizationDataset(Dataset):
29
+ def __init__(self, paths, split='train'):
30
+ if split == 'train':
31
+ self.transforms = transforms.Compose([
32
+ transforms.Resize((SIZE, SIZE), Image.BICUBIC),
33
+ transforms.RandomHorizontalFlip(), # A little data augmentation!
34
+ ])
35
+ elif split == 'val':
36
+ self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
37
+
38
+ self.split = split
39
+ self.size = SIZE
40
+ self.paths = paths
41
+
42
+ def __getitem__(self, idx):
43
+ img = Image.open(self.paths[idx]).convert("RGB")
44
+ img = self.transforms(img)
45
+ img = np.array(img)
46
+ img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
47
+ img_lab = transforms.ToTensor()(img_lab)
48
+ L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
49
+ ab = img_lab[[1, 2], ...] / 110. # Between -1 and 1
50
+
51
+ return {'L': L, 'ab': ab}
52
+
53
+ def __len__(self):
54
+ return len(self.paths)
55
+
56
+ def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # A handy function to make our dataloaders
57
+ dataset = ColorizationDataset(**kwargs)
58
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
59
+ pin_memory=pin_memory)
60
+ return dataloader
61
+
62
+ class UnetBlock(nn.Module):
63
+ def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
64
+ innermost=False, outermost=False):
65
+ super().__init__()
66
+ self.outermost = outermost
67
+ if input_c is None: input_c = nf
68
+ downconv = nn.Conv2d(input_c, ni, kernel_size=4,
69
+ stride=2, padding=1, bias=False)
70
+ downrelu = nn.LeakyReLU(0.2, True)
71
+ downnorm = nn.BatchNorm2d(ni)
72
+ uprelu = nn.ReLU(True)
73
+ upnorm = nn.BatchNorm2d(nf)
74
+
75
+ if outermost:
76
+ upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
77
+ stride=2, padding=1)
78
+ down = [downconv]
79
+ up = [uprelu, upconv, nn.Tanh()]
80
+ model = down + [submodule] + up
81
+ elif innermost:
82
+ upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
83
+ stride=2, padding=1, bias=False)
84
+ down = [downrelu, downconv]
85
+ up = [uprelu, upconv, upnorm]
86
+ model = down + up
87
+ else:
88
+ upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
89
+ stride=2, padding=1, bias=False)
90
+ down = [downrelu, downconv, downnorm]
91
+ up = [uprelu, upconv, upnorm]
92
+ if dropout: up += [nn.Dropout(0.5)]
93
+ model = down + [submodule] + up
94
+ self.model = nn.Sequential(*model)
95
+
96
+ def forward(self, x):
97
+ if self.outermost:
98
+ return self.model(x)
99
+ else:
100
+ return torch.cat([x, self.model(x)], 1)
101
+
102
+ class Unet(nn.Module):
103
+ def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
104
+ super().__init__()
105
+ unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
106
+ for _ in range(n_down - 5):
107
+ unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
108
+ out_filters = num_filters * 8
109
+ for _ in range(3):
110
+ unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
111
+ out_filters //= 2
112
+ self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
113
+
114
+ def forward(self, x):
115
+ return self.model(x)
116
+
117
+ class PatchDiscriminator(nn.Module):
118
+ def __init__(self, input_c, num_filters=64, n_down=3):
119
+ super().__init__()
120
+ model = [self.get_layers(input_c, num_filters, norm=False)]
121
+ model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2)
122
+ for i in range(n_down)] # the 'if' statement is taking care of not using
123
+ # stride of 2 for the last block in this loop
124
+ model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)] # Make sure to not use normalization or
125
+ # activation for the last layer of the model
126
+ self.model = nn.Sequential(*model)
127
+
128
+ def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True): # when needing to make some repeatitive blocks of layers,
129
+ layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)] # it's always helpful to make a separate method for that purpose
130
+ if norm: layers += [nn.BatchNorm2d(nf)]
131
+ if act: layers += [nn.LeakyReLU(0.2, True)]
132
+ return nn.Sequential(*layers)
133
+
134
+ def forward(self, x):
135
+ return self.model(x)
136
+
137
+ class GANLoss(nn.Module):
138
+ def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
139
+ super().__init__()
140
+ self.register_buffer('real_label', torch.tensor(real_label))
141
+ self.register_buffer('fake_label', torch.tensor(fake_label))
142
+ if gan_mode == 'vanilla':
143
+ self.loss = nn.BCEWithLogitsLoss()
144
+ elif gan_mode == 'lsgan':
145
+ self.loss = nn.MSELoss()
146
+
147
+ def get_labels(self, preds, target_is_real):
148
+ if target_is_real:
149
+ labels = self.real_label
150
+ else:
151
+ labels = self.fake_label
152
+ return labels.expand_as(preds)
153
+
154
+ def __call__(self, preds, target_is_real):
155
+ labels = self.get_labels(preds, target_is_real)
156
+ loss = self.loss(preds, labels)
157
+ return loss
158
+
159
+ def init_weights(net, init='norm', gain=0.02):
160
+
161
+ def init_func(m):
162
+ classname = m.__class__.__name__
163
+ if hasattr(m, 'weight') and 'Conv' in classname:
164
+ if init == 'norm':
165
+ nn.init.normal_(m.weight.data, mean=0.0, std=gain)
166
+ elif init == 'xavier':
167
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
168
+ elif init == 'kaiming':
169
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
170
+
171
+ if hasattr(m, 'bias') and m.bias is not None:
172
+ nn.init.constant_(m.bias.data, 0.0)
173
+ elif 'BatchNorm2d' in classname:
174
+ nn.init.normal_(m.weight.data, 1., gain)
175
+ nn.init.constant_(m.bias.data, 0.)
176
+
177
+ net.apply(init_func)
178
+ print(f"model initialized with {init} initialization")
179
+ return net
180
+
181
+ def init_model(model, device):
182
+ model = model.to(device)
183
+ model = init_weights(model)
184
+ return model
185
+
186
+ class MainModel(nn.Module):
187
+ def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
188
+ beta1=0.5, beta2=0.999, lambda_L1=100.):
189
+ super().__init__()
190
+
191
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
192
+ self.lambda_L1 = lambda_L1
193
+
194
+ if net_G is None:
195
+ self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
196
+ else:
197
+ self.net_G = net_G.to(self.device)
198
+ self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
199
+ self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
200
+ self.L1criterion = nn.L1Loss()
201
+ self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
202
+ self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
203
+
204
+ def set_requires_grad(self, model, requires_grad=True):
205
+ for p in model.parameters():
206
+ p.requires_grad = requires_grad
207
+
208
+ def setup_input(self, data):
209
+ self.L = data['L'].to(self.device)
210
+ self.ab = data['ab'].to(self.device)
211
+
212
+ def forward(self):
213
+ self.fake_color = self.net_G(self.L)
214
+
215
+ def backward_D(self):
216
+ fake_image = torch.cat([self.L, self.fake_color], dim=1)
217
+ fake_preds = self.net_D(fake_image.detach())
218
+ self.loss_D_fake = self.GANcriterion(fake_preds, False)
219
+ real_image = torch.cat([self.L, self.ab], dim=1)
220
+ real_preds = self.net_D(real_image)
221
+ self.loss_D_real = self.GANcriterion(real_preds, True)
222
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
223
+ self.loss_D.backward()
224
+
225
+ def backward_G(self):
226
+ fake_image = torch.cat([self.L, self.fake_color], dim=1)
227
+ fake_preds = self.net_D(fake_image)
228
+ self.loss_G_GAN = self.GANcriterion(fake_preds, True)
229
+ self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
230
+ self.loss_G = self.loss_G_GAN + self.loss_G_L1
231
+ self.loss_G.backward()
232
+
233
+ def optimize(self):
234
+ self.forward()
235
+ self.net_D.train()
236
+ self.set_requires_grad(self.net_D, True)
237
+ self.opt_D.zero_grad()
238
+ self.backward_D()
239
+ self.opt_D.step()
240
+
241
+ self.net_G.train()
242
+ self.set_requires_grad(self.net_D, False)
243
+ self.opt_G.zero_grad()
244
+ self.backward_G()
245
+ self.opt_G.step()
246
+
247
+ class AverageMeter:
248
+ def __init__(self):
249
+ self.reset()
250
+
251
+ def reset(self):
252
+ self.count, self.avg, self.sum = [0.] * 3
253
+
254
+ def update(self, val, count=1):
255
+ self.count += count
256
+ self.sum += count * val
257
+ self.avg = self.sum / self.count
258
+
259
+ def create_loss_meters():
260
+ loss_D_fake = AverageMeter()
261
+ loss_D_real = AverageMeter()
262
+ loss_D = AverageMeter()
263
+ loss_G_GAN = AverageMeter()
264
+ loss_G_L1 = AverageMeter()
265
+ loss_G = AverageMeter()
266
+
267
+ return {'loss_D_fake': loss_D_fake,
268
+ 'loss_D_real': loss_D_real,
269
+ 'loss_D': loss_D,
270
+ 'loss_G_GAN': loss_G_GAN,
271
+ 'loss_G_L1': loss_G_L1,
272
+ 'loss_G': loss_G}
273
+
274
+ def update_losses(model, loss_meter_dict, count):
275
+ for loss_name, loss_meter in loss_meter_dict.items():
276
+ loss = getattr(model, loss_name)
277
+ loss_meter.update(loss.item(), count=count)
278
+
279
+ def lab_to_rgb(L, ab):
280
+ """
281
+ Takes a batch of images
282
+ """
283
+
284
+ L = (L + 1.) * 50.
285
+ ab = ab * 110.
286
+ Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
287
+ rgb_imgs = []
288
+ for img in Lab:
289
+ img_rgb = lab2rgb(img)
290
+ rgb_imgs.append(img_rgb)
291
+ return np.stack(rgb_imgs, axis=0)
292
+
293
+ def visualize(model, data, dims):
294
+ model.net_G.eval()
295
+ with torch.no_grad():
296
+ model.setup_input(data)
297
+ model.forward()
298
+ model.net_G.train()
299
+ fake_color = model.fake_color.detach()
300
+ real_color = model.ab
301
+ L = model.L
302
+ fake_imgs = lab_to_rgb(L, fake_color)
303
+ real_imgs = lab_to_rgb(L, real_color)
304
+ for i in range(1):
305
+ # t_img = transforms.Resize((dims[0], dims[1]))(t_img)
306
+ img = Image.fromarray(np.uint8(fake_imgs[i]))
307
+ img = cv.resize(fake_imgs[i], dsize=(dims[1], dims[0]), interpolation=cv.INTER_CUBIC)
308
+ # st.text(f"Size of fake image {fake_imgs[i].shape} \n Type of image = {type(fake_imgs[i])}")
309
+ st.image(img, caption="Output image", use_column_width='auto', clamp=True)
310
+
311
+ def log_results(loss_meter_dict):
312
+ for loss_name, loss_meter in loss_meter_dict.items():
313
+ print(f"{loss_name}: {loss_meter.avg:.5f}")
314
+
315
+ # pip install fastai==2.4
316
+ from fastai.vision.learner import create_body
317
+ from torchvision.models.resnet import resnet18
318
+ from fastai.vision.models.unet import DynamicUnet
319
+
320
+ def build_res_unet(n_input=1, n_output=2, size=256):
321
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
322
+ body = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
323
+ net_G = DynamicUnet(body, n_output, (size, size)).to(device)
324
+ return net_G
325
+
326
+ net_G = build_res_unet(n_input=1, n_output=2, size=256)
327
+ net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
328
+ model = MainModel(net_G=net_G)
329
+ model.load_state_dict(torch.load("final_model_weights.pt", map_location=device))
330
+
331
+ class MyDataset(torch.utils.data.Dataset):
332
+ def __init__(self, img_list):
333
+ super(MyDataset, self).__init__()
334
+ self.img_list = img_list
335
+ self.augmentations = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
336
+
337
+
338
+ def __len__(self):
339
+ return len(self.img_list)
340
+
341
+ def __getitem__(self, idx):
342
+ img = self.img_list[idx]
343
+ img = self.augmentations(img)
344
+ img = np.array(img)
345
+ img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
346
+ img_lab = transforms.ToTensor()(img_lab)
347
+ L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
348
+ ab = img_lab[[1, 2], ...] / 110.
349
+ return {'L': L, 'ab': ab}
350
+
351
+ def make_dataloaders2(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # A handy function to make our dataloaders
352
+ dataset = MyDataset(**kwargs)
353
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
354
+ pin_memory=pin_memory)
355
+ return dataloader
356
+
357
+ file_up = st.file_uploader("Upload an jpg image", type="jpg")
358
+ if file_up is not None:
359
+ im = Image.open(file_up)
360
+ st.text(body=f"Size of uploaded image {im.shape}")
361
+ a = im.shape
362
+ st.image(im, caption="Uploaded Image.", use_column_width='auto')
363
+ test_dl = make_dataloaders2(img_list=[im])
364
+ for data in test_dl:
365
+ model.setup_input(data)
366
+ model.optimize()
367
+ visualize(model, data, a)