import time

import numpy as np
from skimage.color import rgb2lab, lab2rgb
import matplotlib.pyplot as plt
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from torchvision.models import resnet18
from torchvision.models import mobilenet_v2
import torch


class AverageMeter:
    """Running weighted average of a scalar value (e.g. a loss)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated count, sum and average."""
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        """Accumulate ``val`` with weight ``count`` and refresh ``avg``.

        Args:
            val: Value to accumulate (a Python number).
            count: Weight of ``val``, e.g. the number of samples it covers.
        """
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count


def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a ``DynamicUnet`` generator on a pretrained ResNet-18 encoder.

    Args:
        n_input: Number of input channels for the encoder.
        n_output: Number of output channels of the U-Net.
        size: Height and width passed to ``DynamicUnet``.

    Returns:
        The network, moved to CUDA when available, otherwise to CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # cut=-2 drops the last two child modules of resnet18 (presumably the
    # pooling and classification head) so only the conv trunk remains.
    body = create_body(resnet18(pretrained=True), n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G


def build_mobilenet_unet(n_input=1, n_output=2, size=256):
    """Build a ``DynamicUnet`` generator on a pretrained MobileNetV2 encoder.

    Args:
        n_input: Number of input channels for the encoder.
        n_output: Number of output channels of the U-Net.
        size: Height and width passed to ``DynamicUnet``.

    Returns:
        The network, moved to CUDA when available, otherwise to CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mobilenet = mobilenet_v2(pretrained=True)
    # NOTE(review): cut=-2 removes the last two modules of `features`;
    # confirm this truncation depth is intended for MobileNetV2.
    body = create_body(mobilenet.features, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G


def create_loss_meters():
    """Return a fresh ``AverageMeter`` for each tracked GAN loss, keyed by name."""
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}


def update_losses(model, loss_meter_dict, count):
    """Feed the model's current loss values into their meters.

    Args:
        model: Object exposing one attribute per key of ``loss_meter_dict``;
            each attribute must support ``.item()`` (e.g. a scalar tensor).
        loss_meter_dict: Mapping of loss name to ``AverageMeter``.
        count: Weight for this update, e.g. the batch size.
    """
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)


def lab_to_rgb(L, ab):
    """Convert a batch of normalized L*a*b* tensors to RGB images.

    ``L`` is mapped back via ``(L + 1) * 50`` and ``ab`` via ``ab * 110``
    (presumably undoing a [-1, 1] normalization). The two are concatenated
    along dim 1 and each image is converted with ``skimage.color.lab2rgb``.

    Args:
        L: Lightness tensor; assumes (batch, 1, H, W) — TODO confirm.
        ab: Chroma tensor; assumes (batch, 2, H, W) — TODO confirm.

    Returns:
        ``np.ndarray`` stacking the per-image RGB arrays, channels last.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    # detach() lets tensors that still track gradients be converted to NumPy.
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    return np.stack([lab2rgb(img) for img in lab], axis=0)


def visualize(model, data, save=True):
    """Plot grayscale inputs, predicted colors and real colors side by side.

    The generator is run in eval mode without gradients and then switched
    back to train mode. Up to five samples are drawn: row 1 is the L
    channel, row 2 the prediction, and row 3 the ground truth.

    Args:
        model: Object exposing ``net_G``, ``setup_input``, ``forward`` and,
            after ``forward``, the tensors ``L``, ``ab`` and ``fake_color``.
        data: Batch passed to ``model.setup_input``.
        save: If true, also write ``colorization_<timestamp>.png`` to the
            current working directory.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))
    # Batches with fewer than five samples would otherwise raise IndexError.
    n_show = min(5, len(fake_imgs))
    for i in range(n_show):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    # Save before show(): on some backends show() closes/clears the figure,
    # which would otherwise produce a blank image on disk.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()


def log_results(loss_meter_dict):
    """Print each meter's running average, one ``name: value`` line per loss."""
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")