Spaces:
Sleeping
Sleeping
import time | |
import numpy as np | |
from skimage.color import rgb2lab, lab2rgb | |
import matplotlib.pyplot as plt | |
from fastai.vision.learner import create_body | |
from fastai.vision.models.unet import DynamicUnet | |
from torchvision.models import resnet18 | |
from torchvision.models import mobilenet_v2 | |
import torch | |
class AverageMeter:
    """Running, count-weighted mean of a scalar value (e.g. a per-batch loss).

    Attributes:
        count: total weight accumulated so far (sum of every ``count`` passed
            to ``update``).
        sum: accumulated ``count * val`` products.
        avg: ``sum / count``, recomputed on every ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated weight, sum and average."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Fold in ``val`` with weight ``count`` and refresh ``avg``.

        With plain numbers, raises ZeroDivisionError if the accumulated
        weight is still zero after adding ``count``.
        """
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a fastai DynamicUnet generator on a pretrained ResNet-18 encoder.

    Args:
        n_input: number of input channels, forwarded to ``create_body`` as
            ``n_in`` (presumably 1 for the L channel — confirm with callers).
        n_output: number of output channels of the U-Net head (presumably the
            2 ab channels).
        size: side length of the square image size handed to DynamicUnet.

    Returns:
        The DynamicUnet, moved to CUDA when available, otherwise to CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # cut=-2 makes fastai's create_body drop the backbone's last two child
    # modules (for torchvision's ResNet-18: the pooling layer and classifier).
    encoder = create_body(resnet18(pretrained=True), n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size)).to(device)
def build_mobilenet_unet(n_input=1, n_output=2, size=256):
    """Build a fastai DynamicUnet generator on a pretrained MobileNetV2 encoder.

    Only the ``features`` sub-network of MobileNetV2 is used as the encoder.

    Args:
        n_input: number of input channels, forwarded to ``create_body`` as
            ``n_in`` (presumably 1 for the L channel — confirm with callers).
        n_output: number of output channels of the U-Net head (presumably the
            2 ab channels).
        size: side length of the square image size handed to DynamicUnet.

    Returns:
        The DynamicUnet, moved to CUDA when available, otherwise to CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    feature_extractor = mobilenet_v2(pretrained=True).features
    # cut=-2 truncates the features stack before its last two entries.
    encoder = create_body(feature_extractor, pretrained=True, n_in=n_input, cut=-2)
    generator = DynamicUnet(encoder, n_output, (size, size))
    return generator.to(device)
def create_loss_meters():
    """Return a fresh AverageMeter per tracked loss, keyed by loss name.

    Keys, in insertion order: loss_D_fake, loss_D_real, loss_D,
    loss_G_GAN, loss_G_L1, loss_G.
    """
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {name: AverageMeter() for name in loss_names}
def update_losses(model, loss_meter_dict, count):
    """Push the model's current loss values into their meters.

    For every key in ``loss_meter_dict`` the attribute of the same name is
    read from ``model``, converted with ``.item()``, and passed to that
    meter's ``update`` with weight ``count``.
    """
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)
def lab_to_rgb(L, ab):
    """Convert a batch of normalized L / ab tensors to RGB numpy images.

    Args:
        L: lightness tensor, concatenated with ``ab`` along dim 1; presumably
            shape (N, 1, H, W) scaled to [-1, 1], since ``(L + 1) * 50``
            maps that range onto L* in [0, 100] — confirm with the data loader.
        ab: chroma tensor, presumably shape (N, 2, H, W) scaled so that
            ``ab * 110`` recovers the a*/b* values — confirm with the loader.

    Returns:
        np.ndarray of shape (N, H, W, 3) stacking skimage ``lab2rgb`` output
        per image. An empty batch yields an empty (0, H, W, 3) float64 array.
    """
    # Undo the normalization applied to the network inputs/outputs.
    L = (L + 1.) * 50.
    ab = ab * 110.
    # Channels-last for skimage; detach() so tensors that still track
    # gradients can be converted (Tensor.numpy() refuses them otherwise).
    lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    if lab.shape[0] == 0:
        # np.stack rejects an empty sequence; return an empty RGB batch instead.
        return np.empty(lab.shape[:3] + (3,), dtype=np.float64)
    return np.stack([lab2rgb(img) for img in lab], axis=0)
def visualize(model, data, save=True, n_images=5):
    """Plot grayscale inputs, generated colorizations and ground truth.

    Runs one forward pass of ``model`` on ``data`` without gradient tracking,
    then draws a 3-row grid: row 1 the L channel, row 2 the predicted colour
    image (L + fake ab), row 3 the real colour image (L + real ab).

    Args:
        model: colorization model exposing ``net_G``, ``setup_input(data)``,
            ``forward()`` and, after the forward pass, ``L``, ``ab`` and
            ``fake_color`` tensors.
        data: one batch, passed unchanged to ``model.setup_input``.
        save: when True, also write the figure to
            ``colorization_<unix time>.png`` in the working directory.
        n_images: maximum number of samples (columns) to plot; clipped to
            the batch size so batches smaller than this no longer raise
            IndexError.
    """
    net_G = model.net_G
    # Restore the caller's mode afterwards instead of forcing train mode;
    # fall back to training mode (the old behaviour) if the flag is absent.
    was_training = getattr(net_G, "training", True)
    net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        net_G.train(was_training)

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    n_cols = min(n_images, len(L))
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_cols):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")

    # Save before show(): a blocking show() closes the figure, and matplotlib
    # documents saving first when both a file and a window are wanted.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
    # Release the figure so repeated calls (e.g. once per epoch) do not
    # accumulate open figures.
    plt.close(fig)
def log_results(loss_meter_dict):
    """Print one ``name: avg`` line per meter, with avg to 5 decimal places."""
    report_lines = (
        f"{name}: {meter.avg:.5f}" for name, meter in loss_meter_dict.items()
    )
    for line in report_lines:
        print(line)