"""Gradio demo that colorizes grayscale images with a pretrained GAN generator.

The generator (a fastai ``DynamicUnet`` over a ResNet-34 encoder) is downloaded
from the Hugging Face Hub at import time. It receives the Lab L channel of a
256x256 image and predicts two channels that are treated as Lab a/b and
recombined with L before conversion back to RGB.
"""

from io import BytesIO  # NOTE(review): currently unused in this file

import gradio as gr
import numpy as np
import requests  # NOTE(review): currently unused in this file
import torch
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from huggingface_hub import hf_hub_download
from PIL import Image
from skimage.color import lab2rgb, rgb2lab
from torchvision import transforms
from torchvision.models import resnet34

repo_id = "Hammad712/GAN-Colorization-Model"
model_filename = "generator.pt"
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

# Single device shared by the model and every input tensor: GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_generator(n_input=1, n_output=2, size=256):
    """Build the U-Net generator and move it to ``device``.

    Args:
        n_input: Number of input channels for the encoder's first layer
            (1 here: the Lab L channel).
        n_output: Number of output channels (2 here: the a and b channels).
        size: Height and width the U-Net is built for.

    Returns:
        A ``DynamicUnet`` placed on ``device``. Its weights are replaced by
        ``load_state_dict`` below.
    """
    # cut=-2 keeps every child module of the ResNet except the last two
    # (its pooling and classifier head), leaving a convolutional encoder.
    backbone = create_body(resnet34(), pretrained=True, n_in=n_input, cut=-2)
    return DynamicUnet(backbone, n_output, (size, size)).to(device)


# Initialize and load the model.
G_net = build_generator(n_input=1, n_output=2, size=256)
# NOTE(review): torch.load unpickles the downloaded file. It comes from a fixed
# Hub repo, but since only a state dict is needed, consider weights_only=True.
G_net.load_state_dict(torch.load(model_path, map_location=device))
G_net.eval()


def preprocess_image(img, size=256):
    """Turn a PIL image into the normalized L-channel tensor the generator expects.

    Args:
        img: Any PIL image. It is converted to RGB first, so grayscale ("L")
            and RGBA uploads are accepted.
        size: Side length the image is resized to. The default 256 matches
            the generator built above.

    Returns:
        A float32 tensor of shape (1, 1, size, size) on ``device`` holding
        the Lab L channel mapped from skimage's [0, 100] range to [-1, 1].
    """
    rgb = img.convert("RGB")
    rgb = transforms.Resize(
        (size, size), interpolation=transforms.InterpolationMode.BICUBIC
    )(rgb)
    lab = rgb2lab(np.array(rgb)).astype("float32")
    # ToTensor reorders HWC -> CHW; float input is not rescaled by 1/255.
    lab = transforms.ToTensor()(lab)
    L = lab[[0], ...] / 50.0 - 1.0  # keep the channel axis: (1, H, W)
    return L.unsqueeze(0).to(device)  # add batch axis: (1, 1, H, W)


def colorize_image(img, model):
    """Predict colors for ``img`` with ``model``.

    Args:
        img: PIL image to colorize (see ``preprocess_image``).
        model: Generator mapping a normalized L tensor to two chroma channels.

    Returns:
        An ndarray of shape (batch, H, W, 3) of RGB floats as returned by
        ``skimage.color.lab2rgb``. Batch is 1 for a single input image.
    """
    L = preprocess_image(img)
    with torch.no_grad():
        ab = model(L)
    # Undo the input normalization. The *110 factor assumes the model was
    # trained on ab / 110 (training code not shown here; confirm).
    L = (L + 1.0) * 50.0
    ab = ab * 110.0
    lab_batch = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab) for lab in lab_batch], axis=0)


def colorize(img):
    """Gradio callback: colorize ``img`` and return it as an 8-bit RGB PIL image.

    The result is 256x256, whatever the size of the upload.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None`` when the
            input component is empty).
    """
    if img is None:
        raise gr.Error("Please upload an image to colorize.")
    rgb = colorize_image(img, G_net)[0]
    # Round (instead of truncating) to 0-255. The clip guards against values
    # slightly outside [0, 1].
    rgb8 = np.clip(np.rint(rgb * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(rgb8)


app = gr.Interface(
    fn=colorize,
    inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="AI Image Colorization",
    description="Upload a black and white image, and the AI will colorize it.",
    # NOTE(review): newer Gradio releases rename this to flagging_mode="never";
    # confirm against the pinned Gradio version before changing.
    allow_flagging="never",
)
app.launch()