# Jatayu / app.py
# ItsJATAYU's picture
# Update app.py
# 8ec48a4 verified
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from skimage.color import rgb2lab, lab2rgb
import numpy as np
import requests
from io import BytesIO
# Hugging Face Hub repository and file name of the generator checkpoint.
repo_id = "Hammad712/GAN-Colorization-Model"
model_filename = "generator.pt"
# Runs at import time. The returned value is later handed to torch.load as
# the checkpoint location.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
from fastai.vision.learner import create_body
from torchvision.models import resnet34
from fastai.vision.models.unet import DynamicUnet
def build_generator(n_input=1, n_output=2, size=256):
    """Assemble the generator network: a DynamicUnet on top of a resnet34 body.

    No checkpoint is loaded here; the caller is responsible for loading
    trained weights into the returned network.

    Args:
        n_input: Forwarded to ``create_body`` as ``n_in``.
        n_output: Forwarded to ``DynamicUnet`` as its second positional
            argument (presumably the number of output channels -- confirm
            against the fastai version in use).
        size: Side length; ``(size, size)`` is forwarded to ``DynamicUnet``.

    Returns:
        The constructed network, moved to CUDA when torch reports it as
        available and to the CPU otherwise.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    encoder = create_body(resnet34(), pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target)
# Initialize and load the model
# Use the GPU when torch reports one, otherwise the CPU. This module-level
# `device` is also the one preprocess_image moves its input tensor to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G_net = build_generator(n_input=1, n_output=2, size=256)
# NOTE(review): torch.load unpickles the downloaded file, so the checkpoint
# repository must be trusted. Consider weights_only=True if the installed
# torch version supports it and the checkpoint is a plain state dict.
G_net.load_state_dict(torch.load(model_path, map_location=device))
# Put the network in evaluation mode once, before any request is served.
G_net.eval()
def preprocess_image(img):
    """Turn a PIL image into the normalised lightness tensor fed to the model.

    The image is converted to RGB, resized to 256x256 with bicubic
    interpolation and converted to Lab. Only channel 0 of the Lab result is
    kept and rescaled as ``L / 50 - 1`` (so a lightness of 0..100 lands in
    -1..1).

    Args:
        img: A PIL image (anything exposing ``convert("RGB")``).

    Returns:
        A float32 tensor with a leading batch axis added, moved to the
        module-level ``device``.
    """
    resize = transforms.Resize((256, 256), Image.BICUBIC)
    rgb_pixels = np.array(resize(img.convert("RGB")))
    lab_tensor = transforms.ToTensor()(rgb2lab(rgb_pixels).astype("float32"))
    # Index with a list so the channel axis is kept rather than squeezed out.
    lightness = lab_tensor[[0], ...] / 50. - 1.
    return lightness.unsqueeze(0).to(device)
def colorize_image(img, model):
    """Colorize one image with ``model`` and return the result as RGB arrays.

    The lightness channel comes from ``preprocess_image``; the model output
    is used as the two colour channels. Both are scaled back (lightness with
    ``(L + 1) * 50``, colour with ``* 110``) before the Lab -> RGB conversion.

    Args:
        img: A PIL image, forwarded to ``preprocess_image``.
        model: Callable taking the lightness tensor and returning a tensor
            that can be concatenated with it along axis 1.

    Returns:
        A NumPy array stacking one ``lab2rgb`` result per batch item along
        axis 0, each laid out channels-last.
    """
    lightness = preprocess_image(img)
    with torch.no_grad():
        predicted_ab = model(lightness)

    # Undo the input/output normalisation, then go channels-last for skimage.
    lab_batch = torch.cat([(lightness + 1.) * 50., predicted_ab * 110.], dim=1)
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()

    return np.stack([lab2rgb(lab_pixels) for lab_pixels in lab_batch], axis=0)
def colorize(img):
    """Gradio callback: colorize an uploaded image with the global generator.

    Args:
        img: The PIL image supplied by the Gradio input component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A PIL image built from the first colorized result, or ``None`` when
        no image was supplied.
    """
    # Gradio hands over None for an empty input; without this guard the call
    # fails later with an AttributeError on ``img.convert``.
    if img is None:
        return None

    rgb = colorize_image(img, G_net)[0]
    # Round to the nearest 8-bit level; a bare astype(uint8) truncates and
    # biases every channel downwards. Clip so the cast can never wrap.
    pixels = np.clip(np.rint(rgb * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
# Single-function web UI: one image in, one image out, both exchanged with
# `colorize` as PIL images.
app = gr.Interface(
fn=colorize,
inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
outputs=gr.Image(type="pil", label="Colorized Image"),
title="AI Image Colorization",
description="Upload a black and white image, and the AI will colorize it.",
# Flagging is turned off for this interface.
allow_flagging="never"
)
# Start the server as soon as this file is executed.
app.launch()