|
import gradio as gr |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
from torchvision import transforms |
|
from skimage.color import rgb2lab, lab2rgb |
|
import numpy as np |
|
import requests |
|
from io import BytesIO |
|
|
|
# Hugging Face Hub location of the pretrained GAN generator weights.
repo_id = "Hammad712/GAN-Colorization-Model"

model_filename = "generator.pt"

# Fetch the weights file and keep its local filesystem path for torch.load.
# NOTE(review): this runs at import time and needs network access unless the
# file is already present in the local Hub cache — confirm for offline deploys.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
|
|
from fastai.vision.learner import create_body |
|
from torchvision.models import resnet34 |
|
from fastai.vision.models.unet import DynamicUnet |
|
|
|
def build_generator(n_input=1, n_output=2, size=256):
    """Construct the U-Net generator network used for colorization.

    A randomly initialised torchvision ResNet-34 is turned into an encoder by
    fastai's ``create_body`` (``cut=-2`` keeps all but its last two children,
    and ``n_in`` adapts the first layer to ``n_input`` channels). The encoder is
    then wrapped in a ``DynamicUnet`` producing ``n_output`` channels for a
    ``size`` x ``size`` input.

    Args:
        n_input: Number of input channels (this app uses 1, the L channel).
        n_output: Number of output channels (this app uses 2, the a/b channels).
        size: Spatial side length the U-Net is built for.

    Returns:
        The ``DynamicUnet`` module, moved to CUDA when available, else the CPU.
    """
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # NOTE(review): resnet34() is built without downloaded weights, so
    # pretrained=True presumably only influences how create_body adapts the
    # first layer; the weights are replaced when a state dict is loaded anyway.
    encoder = create_body(resnet34(), pretrained=True, n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size)).to(target)
|
|
|
|
|
# Device used for inference; input tensors are moved here during preprocessing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the generator architecture, then load the downloaded weights into it.
G_net = build_generator(n_input=1, n_output=2, size=256)

# torch.load unpickles the file. The checkpoint comes from a remote repository,
# so weights_only=True restricts unpickling to tensors and plain containers
# (all a state dict needs) and stops a tampered file from executing code.
# Requires torch >= 1.13; it is the default from torch 2.6 onwards.
G_net.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

# Switch to evaluation mode (e.g. BatchNorm uses running statistics).
G_net.eval()
|
|
|
def preprocess_image(img, size=256):
    """Convert a PIL image into the normalized L-channel tensor fed to the generator.

    The image is converted to RGB, resized to ``size`` x ``size`` with bicubic
    interpolation, and converted to CIE L*a*b*. Only the L channel is kept,
    rescaled from 0..100 to -1..1.

    Args:
        img: Input ``PIL.Image.Image`` in any mode convertible to RGB.
        size: Side length to resize to. Defaults to 256, the size the
            module-level generator is built for.

    Returns:
        A float32 tensor of shape (1, 1, size, size) on the module-level ``device``.
    """
    rgb = img.convert("RGB")
    # InterpolationMode is torchvision's documented interpolation type; passing
    # PIL's integer resampling constants is a legacy path that some releases warn on.
    rgb = transforms.Resize(
        (size, size), interpolation=transforms.InterpolationMode.BICUBIC
    )(rgb)
    lab = rgb2lab(np.array(rgb)).astype("float32")
    # ToTensor turns the HxWx3 float array into a 3xHxW tensor; float input is
    # not rescaled (only uint8 input is divided by 255).
    lab = transforms.ToTensor()(lab)
    # Index with [0] (a list) to keep the channel axis: result is 1xHxW.
    lightness = lab[[0], ...] / 50. - 1.
    return lightness.unsqueeze(0).to(device)
|
|
|
def colorize_image(img, model):
    """Run ``model`` on ``img`` and return the colorized result as RGB arrays.

    Args:
        img: Input ``PIL.Image.Image``; it goes through ``preprocess_image``
            (resize and L-channel extraction).
        model: Generator that maps the (N, 1, H, W) normalized L tensor to an
            a/b tensor. NOTE(review): assumed to be (N, 2, H, W) with values
            near [-1, 1], since the output is scaled by 110 below.

    Returns:
        A NumPy array of shape (N, H, W, 3) holding the float RGB images from
        ``skimage.color.lab2rgb``. N is 1 for a single preprocessed image.
    """
    lightness = preprocess_image(img)

    with torch.no_grad():
        ab = model(lightness)

    # Undo the normalization: L from [-1, 1] back to 0..100, and a/b scaled by
    # 110 (presumably the factor used in training — confirm with the model card).
    lab = torch.cat([(lightness + 1.) * 50., ab * 110.], dim=1)
    # NCHW -> NHWC on the CPU, the layout skimage expects.
    lab = lab.permute(0, 2, 3, 1).cpu().numpy()

    # Convert each image separately. A distinct loop name avoids shadowing the
    # `img` parameter, as the original loop did.
    return np.stack([lab2rgb(lab_img) for lab_img in lab], axis=0)
|
|
|
def colorize(img):
    """Gradio callback: colorize one uploaded image with the global generator.

    Args:
        img: The uploaded ``PIL.Image.Image``, or ``None`` when the user
            submits without providing an image.

    Returns:
        The colorized ``PIL.Image.Image`` at the preprocessing resolution
        (256x256 by default), or ``None`` when no image was given.
    """
    if img is None:
        # Gradio passes None for an empty image input. Return an empty output
        # instead of crashing inside preprocessing.
        return None

    colorized = colorize_image(img, G_net)[0]
    # Scale [0, 1] floats to 0..255 and round to the nearest integer. A plain
    # astype(uint8) truncates (e.g. 0.999 * 255 -> 254). The clip guards
    # against values marginally outside [0, 1].
    pixels = np.clip(np.rint(colorized * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
# Gradio UI: a single PIL image input wired to `colorize`, which returns a
# single PIL image output.
app = gr.Interface(
    fn=colorize,
    inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="AI Image Colorization",
    description="Upload a black and white image, and the AI will colorize it.",
    # Disable the "Flag" button. NOTE(review): newer Gradio releases rename
    # this parameter to `flagging_mode` — confirm against the pinned version.
    allow_flagging="never"
)

# Start the web server. NOTE(review): this runs whenever the module is loaded;
# consider an `if __name__ == "__main__":` guard if it is ever imported elsewhere.
app.launch()
|
|