import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr

from models import MainModel, UNetAuto, Autoencoder
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Square input resolution the generators are built with and inputs are resized to.
IMAGE_SIZE = 256

# Output modes offered in the UI, in display order, and the mode selected at start-up.
MODE_CHOICES = ["ResNet", "MobileNet", "Unet", "Comparison"]
DEFAULT_MODE = "ResNet"

# For each mode: which of the (ResNet, MobileNet, Unet) output slots are shown/filled.
# Shared by the inference function and the visibility logic so they cannot disagree.
MODE_OUTPUTS = {
    "ResNet": (True, False, False),
    "MobileNet": (False, True, False),
    "Unet": (False, False, True),
    "Comparison": (True, True, True),
}

# Returned when nothing should be displayed (unknown mode or no input image).
_NO_OUTPUTS = (False, False, False)


# Model loading functions

def load_unet_model(auto_model_path):
    """Build the UNet autoencoder and load its weights.

    Args:
        auto_model_path: Path to a state dict for ``Autoencoder(UNetAuto(...))``.

    Returns:
        The model on ``device``, in eval mode.
    """
    unet = UNetAuto(in_channels=1, out_channels=2).to(device)
    model = Autoencoder(unet).to(device)
    # load_state_dict copies into the already-placed parameters, so the model stays on `device`.
    model.load_state_dict(torch.load(auto_model_path, map_location=device))
    model.eval()
    return model


def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    """Build a generator backbone, wrap it in ``MainModel`` and load both checkpoints.

    Args:
        generator_model_path: State dict for the bare generator network.
        colorization_model_path: State dict for the full ``MainModel``. It is
            loaded after the generator weights, so it takes precedence for any
            keys it contains.
        model_type: ``'resnet'`` or ``'mobilenet'``; selects the generator builder.

    Returns:
        The ``MainModel`` on ``device``, in eval mode.

    Raises:
        ValueError: If ``model_type`` is not a known backbone.
    """
    builders = {'resnet': build_res_unet, 'mobilenet': build_mobilenet_unet}
    try:
        build_generator = builders[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(builders)}"
        ) from None

    net_G = build_generator(n_input=1, n_output=2, size=IMAGE_SIZE)
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


resnet_model = load_model(
    "weight/pascal_res18-unet.pt",
    "weight/pascal_final_model_weights.pt",
    model_type='resnet'
)
mobilenet_model = load_model(
    "weight/mobile-unet.pt",
    "weight/mobile_pascal_final_model_weights.pt",
    model_type='mobilenet'
)
unet_model = load_unet_model("weight/autoencoder.pt")


# Transformations

def preprocess_image(image):
    """Resize a PIL image to IMAGE_SIZE x IMAGE_SIZE and scale its first channel to [-1, 1].

    Returns:
        A float tensor of shape (1, IMAGE_SIZE, IMAGE_SIZE).
    """
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
    # ToTensor maps 8-bit pixels to [0, 1] as (C, H, W); keep channel 0, then map to [-1, 1].
    return transforms.ToTensor()(image)[:1] * 2. - 1.


def postprocess_image(grayscale, prediction, original_size):
    """Merge the input channel with predicted ab channels and return an RGB PIL image.

    Args:
        grayscale: Tensor of shape (1, H, W) as produced by ``preprocess_image``.
        prediction: Model output for ``grayscale.unsqueeze(0)``.
        original_size: ``(width, height)`` to resize the result to.
    """
    # Hand lab_to_rgb both tensors on the CPU so it never mixes a CUDA tensor with a CPU one.
    colorized = lab_to_rgb(grayscale.unsqueeze(0).cpu(), prediction.cpu())[0]
    # NOTE(review): assumes lab_to_rgb returns floats in [0, 1]; clip so any
    # out-of-range value saturates instead of wrapping around in the uint8 cast.
    colorized = np.clip(colorized * 255, 0, 255).astype("uint8")
    return Image.fromarray(colorized).resize(original_size)


# Prediction function with output control

def colorize_image(input_image, mode):
    """Colorize ``input_image`` using only the model(s) that ``mode`` displays.

    Args:
        input_image: Image as a numpy array (from ``gr.Image(type="numpy")``), or
            ``None`` when nothing has been uploaded.
        mode: One of ``MODE_CHOICES``.

    Returns:
        A 3-tuple ``(resnet, mobilenet, unet)`` of PIL images at the input's
        original size; slots not shown in ``mode`` are ``None``.
    """
    wanted = MODE_OUTPUTS.get(mode, _NO_OUTPUTS)
    if input_image is None or not any(wanted):
        return None, None, None

    grayscale_image = Image.fromarray(input_image).convert('L')
    original_size = grayscale_image.size  # (width, height) to restore after inference
    grayscale = preprocess_image(grayscale_image).to(device)
    batch = grayscale.unsqueeze(0)

    # Same order as the output slots: (ResNet, MobileNet, Unet).
    predictors = (resnet_model.net_G, mobilenet_model.net_G, unet_model)

    results = []
    with torch.no_grad():
        for show, predict in zip(wanted, predictors):
            # Skip models whose output would not be displayed.
            results.append(
                postprocess_image(grayscale, predict(batch), original_size) if show else None
            )
    return tuple(results)


# Gradio Interface

def gradio_interface():
    """Build the Gradio Blocks app: image upload, mode selector, and three output slots."""
    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=MODE_CHOICES,
            value=DEFAULT_MODE,
            label="Output Mode"
        )
        submit_button = gr.Button("Submit")

        # Initial visibility must match the default mode; `.change` does not
        # fire for the initial value, so otherwise the default output stays hidden.
        show_resnet, show_mobilenet, show_unet = MODE_OUTPUTS[DEFAULT_MODE]

        # Output components, placed in a single row.
        with gr.Row():
            resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=show_resnet)
            mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=show_mobilenet)
            unet_output = gr.Image(label="Colorized Image (Unet)", visible=show_unet)

        def update_visibility(mode):
            """Show exactly the output slots used by ``mode``."""
            return tuple(
                gr.update(visible=flag) for flag in MODE_OUTPUTS.get(mode, _NO_OUTPUTS)
            )

        # Dynamic event listener for output mode changes
        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output]
        )

        # Submit logic
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output]
        )

    return demo


# Launch
if __name__ == "__main__":
    gradio_interface().launch()