ChiKyi committed on
Commit
4d92358
1 Parent(s): cf2db44

update file

Browse files
dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from skimage.color import rgb2lab, lab2rgb
4
+
5
+ import torch
6
+ from torch import nn, optim
7
+ from torchvision import transforms
8
+ from torch.utils.data import Dataset, DataLoader
9
+
10
+ SIZE = 256
11
+
12
+
13
class ColorizationDataset(Dataset):
    """Dataset yielding normalised L and ab channels for image colorization.

    Every item is read from disk, converted to RGB, resized to
    ``SIZE x SIZE`` and converted to the L*a*b colour space.

    Args:
        paths: Sequence of image file paths.
        split: ``'train'`` (resize + random horizontal flip) or ``'val'``
            (resize only).

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        # Fail fast: an unknown split used to leave ``self.transforms``
        # undefined and only crashed on the first ``__getitem__`` call.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), Image.BICUBIC),
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        else:
            self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Return ``{'L': tensor, 'ab': tensor}`` for the image at ``idx``."""
        # The context manager releases the file handle once the pixel data
        # has been copied by ``convert``.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # L channel rescaled to roughly [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.  # a/b channels rescaled to roughly [-1, 1]

        return {'L': L, 'ab': ab}

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)
40
+
41
+
42
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    """Build a ``DataLoader`` over a ``ColorizationDataset``.

    Args:
        batch_size: Number of samples per batch.
        n_workers: Passed to the loader as ``num_workers``.
        pin_memory: Passed to the loader as ``pin_memory``.
        **kwargs: Forwarded unchanged to ``ColorizationDataset``.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        ColorizationDataset(**kwargs),
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )
infer.py DELETED
@@ -1,90 +0,0 @@
1
- import torch
2
- from PIL import Image
3
- from torchvision import transforms
4
- from matplotlib import pyplot as plt
5
- import gradio as gr
6
-
7
- from models import MainModel # Import class for your main model
8
- from utils import lab_to_rgb, build_res_unet#, build_mobile_unet # Utility to convert LAB to RGB
9
-
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
-
12
-
13
def load_model(generator_model_path, colorization_model_path):
    """Build the ResNet18 U-Net colorization model and load both checkpoints.

    Args:
        generator_model_path: Path to the generator (``net_G``) state dict.
        colorization_model_path: Path to the full ``MainModel`` state dict.

    Returns:
        The ``MainModel`` moved to ``device`` and switched to eval mode.
    """
    # Only the ResNet18 backbone is wired up; the MobileNet variant is disabled.
    generator = build_res_unet(n_input=1, n_output=2, size=256)
    generator_state = torch.load(generator_model_path, map_location=device)
    generator.load_state_dict(generator_state)

    # Wrap the generator in MainModel and load the full weights on top.
    model = MainModel(net_G=generator)
    full_state = torch.load(colorization_model_path, map_location=device)
    model.load_state_dict(full_state)

    model.to(device)
    model.eval()
    return model
31
-
32
- # Load pretrained models
33
resnet_model = load_model(
    generator_model_path="weight/pascal_res18-unet.pt",
    colorization_model_path="weight/pascal_final_model_weights.pt",
)
38
-
39
- # mobilenet_model = load_model(
40
- # "weight/mobile-unet.pt",
41
- # "weight/mobile_pascal_final_model_weights.pt",
42
- # model_type='mobilenet'
43
- # )
44
-
45
- # Transformations
46
def preprocess_image(image):
    """Resize a PIL image to 256x256 and return its first channel as a tensor.

    The ``ToTensor`` output is rescaled with ``* 2 - 1`` (normalize to [-1, 1]).
    """
    resized = image.resize((256, 256))
    as_tensor = transforms.ToTensor()(resized)
    first_channel = as_tensor[:1]
    # NOTE(review): the training dataset derives L via rgb2lab(...) / 50 - 1;
    # confirm that plain grayscale intensity is an acceptable stand-in here.
    return first_channel * 2. - 1.
50
-
51
def postprocess_image(grayscale, prediction):
    """Convert a single L tensor and a predicted ab batch into one RGB image.

    Args:
        grayscale: L-channel tensor without a batch dimension.
        prediction: Batched ab output of the generator.

    Returns:
        The first image of the batch returned by ``lab_to_rgb``.
    """
    # lab_to_rgb concatenates both tensors, so they must share a device.
    # Previously only ``prediction`` was moved to the CPU, which breaks when
    # ``grayscale`` lives on CUDA.
    return lab_to_rgb(grayscale.unsqueeze(0).cpu(), prediction.cpu())[0]
53
-
54
- # Prediction function
55
def colorize_image(input_image):
    """Colorize an uploaded image with the ResNet18 generator.

    Args:
        input_image: Array accepted by ``Image.fromarray``.

    Returns:
        Tuple ``(grayscale PIL image, colorized image)``.
    """
    # Drop any colour information before feeding the network.
    gray_pil = Image.fromarray(input_image).convert('L')
    net_input = preprocess_image(gray_pil).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        predicted_ab = resnet_model.net_G(net_input.unsqueeze(0))

    colorized = postprocess_image(net_input, predicted_ab)
    return gray_pil, colorized
74
-
75
- # Gradio Interface
76
interface = gr.Interface(
    colorize_image,
    gr.Image(type="numpy", label="Upload a Color Image"),
    [
        gr.Image(label="Grayscale Image"),
        gr.Image(label="Colorized Image (ResNet18)"),
    ],
    title="Image Colorization",
    description="Upload a color image",
)

# Only start the Gradio app when this file is run as a script.
if __name__ == '__main__':
    interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
loss.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class GANLoss(nn.Module):
    """Adversarial loss that builds the real/fake target tensor internally.

    Args:
        gan_mode: ``'vanilla'`` uses ``nn.BCEWithLogitsLoss``; ``'lsgan'``
            uses ``nn.MSELoss``.
        real_label: Target value used for real samples.
        fake_label: Target value used for fake samples.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers move with the module on ``.to(device)``.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Previously an unknown mode left ``self.loss`` undefined and the
            # failure only surfaced on the first loss computation.
            raise ValueError(
                f"gan_mode must be 'vanilla' or 'lsgan', got {gan_mode!r}")

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label expanded to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the real or fake target.

        Defined as ``forward`` (not ``__call__``) so ``nn.Module`` hooks keep
        working; ``criterion(preds, flag)`` behaves as before.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
models.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, optim
3
+ from loss import GANLoss
4
+
5
+
6
class UnetBlock(nn.Module):
    """One nesting level of a U-Net: downsample, submodule, upsample.

    Args:
        nf: Channels produced by the up-convolution (and default input width).
        ni: Channels produced by the down-convolution.
        submodule: Inner block placed between the down and up paths.
        input_c: Input channels; defaults to ``nf`` when ``None``.
        dropout: Append ``Dropout(0.5)`` to the up path (intermediate blocks).
        innermost: Build the bottleneck variant (no submodule).
        outermost: Build the outer variant ending in ``Tanh`` without a skip.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c

        # Layers are created in the same order in every variant.
        down_conv = nn.Conv2d(in_channels, ni, kernel_size=4,
                              stride=2, padding=1, bias=False)
        down_act = nn.LeakyReLU(0.2, True)
        down_bn = nn.BatchNorm2d(ni)
        up_act = nn.ReLU(True)
        up_bn = nn.BatchNorm2d(nf)

        if outermost:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                         stride=2, padding=1)
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                         stride=2, padding=1, bias=False)
            layers = [down_act, down_conv, up_act, up_conv, up_bn]
        else:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                         stride=2, padding=1, bias=False)
            layers = [down_act, down_conv, down_bn, submodule,
                      up_act, up_conv, up_bn]
            if dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; every block but the outermost concatenates its input."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)
45
+
46
+
47
class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` modules.

    Args:
        input_c: Channels of the input tensor.
        output_c: Channels of the output tensor.
        n_down: Total number of downsampling steps.
        num_filters: Base channel width.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        # Build from the bottleneck outwards.
        block = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        # Three blocks that each halve the channel count.
        for channels in (widest, widest // 2, widest // 4):
            block = UnetBlock(channels // 2, channels, submodule=block)

        self.model = UnetBlock(output_c, widest // 8, input_c=input_c,
                               submodule=block, outermost=True)

    def forward(self, x):
        """Apply the nested U-Net to ``x``."""
        return self.model(x)
61
+
62
+
63
class PatchDiscriminator(nn.Module):
    """Stack of strided convolution blocks ending in a 1-channel conv output.

    Args:
        input_c: Channels of the input tensor.
        num_filters: Output channels of the first block.
        n_down: Number of normalised blocks after the first one.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for i in range(n_down):
            # The last block of this loop keeps the resolution (stride 1).
            stride = 1 if i == (n_down - 1) else 2
            blocks.append(self.get_layers(num_filters * 2 ** i,
                                          num_filters * 2 ** (i + 1),
                                          s=stride))
        # Final layer: no normalization and no activation.
        blocks.append(self.get_layers(num_filters * 2 ** n_down, 1, s=1,
                                      norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return a conv block, optionally with BatchNorm and LeakyReLU.

        Having one method for these repetitive blocks keeps ``__init__`` short.
        The convolution only carries a bias when no normalization follows.
        """
        parts = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            parts.append(nn.BatchNorm2d(nf))
        if act:
            parts.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*parts)

    def forward(self, x):
        """Return the discriminator output for ``x``."""
        return self.model(x)
85
+
86
+
87
def init_weights(net, init='norm', gain=0.02):
    """Re-initialise the Conv and BatchNorm2d layers of ``net`` in place.

    Args:
        net: Module whose submodules are visited with ``net.apply``.
        init: ``'norm'``, ``'xavier'`` or ``'kaiming'`` scheme for conv weights.
        gain: Std for ``'norm'`` and BatchNorm weights, gain for ``'xavier'``.

    Returns:
        The same ``net`` instance.

    Raises:
        ValueError: If ``init`` is not a supported scheme.
    """
    # Previously an unknown scheme silently kept the default conv weights
    # while still reporting that the model had been initialised.
    if init not in ('norm', 'xavier', 'kaiming'):
        raise ValueError(
            f"init must be 'norm', 'xavier' or 'kaiming', got {init!r}")

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
107
+
108
+
109
def init_model(model, device):
    """Move ``model`` to ``device``, apply ``init_weights`` and return it."""
    return init_weights(model.to(device))
113
+
114
+
115
class MainModel(nn.Module):
    """Generator/discriminator pair with their losses and optimizers.

    Args:
        net_G: Optional generator; a fresh ``Unet`` is built when ``None``.
        lr_G: Learning rate of the generator's Adam optimizer.
        lr_D: Learning rate of the discriminator's Adam optimizer.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
        lambda_L1: Weight of the L1 term in the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is not None:
            self.net_G = net_G.to(self.device)
        else:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()

        adam_betas = (beta1, beta2)
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=adam_betas)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=adam_betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad_(requires_grad)

    def setup_input(self, data):
        """Store the ``'L'`` and ``'ab'`` entries of ``data`` on the device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Run the generator on the stored L input; sets ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss and backpropagate it."""
        # Detach so that no gradient flows back into the generator.
        fake_pair = torch.cat([self.L, self.fake_color], dim=1)
        preds_on_fake = self.net_D(fake_pair.detach())
        self.loss_D_fake = self.GANcriterion(preds_on_fake, False)

        real_pair = torch.cat([self.L, self.ab], dim=1)
        preds_on_real = self.net_D(real_pair)
        self.loss_D_real = self.GANcriterion(preds_on_real, True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and backpropagate it."""
        fake_pair = torch.cat([self.L, self.fake_color], dim=1)
        preds_on_fake = self.net_D(fake_pair)
        # The generator is rewarded when the discriminator labels fakes real.
        self.loss_G_GAN = self.GANcriterion(preds_on_fake, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one discriminator update followed by one generator update."""
        self.forward()

        # Discriminator step.
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Generator step, with the discriminator's parameters frozen.
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ matplotlib
4
+ gradio
5
+ Pillow
6
+ scikit-image
7
+ numpy
8
+ scikit-learn
9
+ fastai
utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import numpy as np
4
+ from skimage.color import rgb2lab, lab2rgb
5
+ import matplotlib.pyplot as plt
6
+ from fastai.vision.learner import create_body
7
+ from fastai.vision.models.unet import DynamicUnet
8
+ from torchvision.models import resnet18
9
+ from torchvision.models import mobilenet_v2
10
+ import torch
11
+
12
+
13
class AverageMeter:
    """Running weighted average of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset count, average and sum to zero."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` weighted by ``count`` and refresh the average."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
24
+
25
def build_res_unet(n_input=1, n_output=2, size=256, pretrained=True):
    """Build a ``DynamicUnet`` generator on top of a ResNet18 body.

    Args:
        n_input: Input channels passed to ``create_body`` as ``n_in``.
        n_output: Output channels of the ``DynamicUnet``.
        size: Side length; the network is built for ``(size, size)`` images.
        pretrained: Forwarded to ``resnet18``. Pass ``False`` to skip fetching
            backbone weights when a checkpoint is loaded right afterwards.
            Defaults to ``True`` (the previous, hard-coded behaviour).

    Returns:
        The generator, moved to CUDA when available and to the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone = resnet18(pretrained=pretrained)
    body = create_body(backbone, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
30
+
31
+ # def build_mobile_unet(n_input=1, n_output=2, size=256):
32
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ # mobilenet_model = mobilenet_v2(pretrained=True)
34
+ # body = create_body(mobilenet_model, n_in=n_input, cut=-2)
35
+ # net_G = DynamicUnet(body, n_output, (size, size)).to(device)
36
+ # return net_G
37
+
38
def create_loss_meters():
    """Return a dict with one fresh ``AverageMeter`` per tracked loss name."""
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}
52
+
53
+
54
def update_losses(model, loss_meter_dict, count):
    """Feed each meter the current value of the same-named attribute of ``model``.

    Args:
        model: Object exposing one attribute per key of ``loss_meter_dict``;
            each attribute must provide ``.item()``.
        loss_meter_dict: Mapping from loss name to meter.
        count: Weight passed to every meter's ``update``.
    """
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)
58
+
59
+
60
def lab_to_rgb(L, ab):
    """
    Takes a batch of images

    ``L`` and ``ab`` are rescaled with ``(L + 1) * 50`` and ``ab * 110``,
    concatenated along dim 1 and converted image by image with ``lab2rgb``.
    Returns the converted images stacked along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    # Channels-last layout on the CPU, as a NumPy array, for lab2rgb.
    lab_batch = torch.cat([lightness, chroma], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
73
+
74
+
75
def visualize(model, data, save=True):
    """Plot input L, generated and ground-truth images for up to 5 samples.

    Args:
        model: Object exposing ``net_G``, ``setup_input``, ``forward`` and,
            after the forward pass, ``fake_color``, ``ab`` and ``L``.
        data: Batch passed to ``model.setup_input``.
        save: When true, also write the figure to
            ``colorization_<timestamp>.png``.
    """
    # Generate in eval mode without gradients, then switch back to train mode.
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    n_cols = 5
    # Batches smaller than 5 used to raise IndexError; show what is available.
    n_shown = min(n_cols, len(fake_imgs))
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_shown):
        # Row 1: grayscale input.
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        # Row 2: generated colours.
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        # Row 3: ground truth.
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
100
+
101
+
102
def log_results(loss_meter_dict):
    """Print ``<name>: <avg>`` with five decimals for every meter."""
    for name in loss_meter_dict:
        average = loss_meter_dict[name].avg
        print("{}: {:.5f}".format(name, average))
weight/mobile-unet.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1fd8ca3f385d9fa334230ed58dbf5965ad1530e9e8ac3b34942c5ef1a7629f9
3
+ size 13849596
weight/mobile_pascal_final_model_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6201bffc36a802ec3f8fa6590b92de19b4285ef5dad8d721b65cc4458e8ff5b8
3
+ size 24937539
weight/pascal_final_model_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:145f0d59355b6b94dce6ad0636cf9de17959d79e687ba00122515111f542242e
3
+ size 135592155
weight/pascal_res18-unet.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb0b3da121e2fe68c3a32d88b8476a8cf8b81d6d8b4cecacd8a6e8fbf51e93e7
3
+ size 124508595