import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr

from models import MainModel, UNetAuto, Autoencoder
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Output-mode labels shown in the UI radio and passed back to the callbacks.
MODE_RESNET = "ResNet"
MODE_MOBILENET = "MobileNet"
MODE_AUTOENCODER = "Autoencoder"
MODE_COMPARISON = "Comparison"
OUTPUT_MODES = [MODE_RESNET, MODE_MOBILENET, MODE_AUTOENCODER, MODE_COMPARISON]

# For each mode: which of the (ResNet, MobileNet, Autoencoder) outputs are
# produced and shown. Shared by colorize_image and the visibility callback so
# the two can never disagree.
_MODE_OUTPUTS = {
    MODE_RESNET: (True, False, False),
    MODE_MOBILENET: (False, True, False),
    MODE_AUTOENCODER: (False, False, True),
    MODE_COMPARISON: (True, True, True),
}


# Model loading functions
def load_autoencoder_model(auto_model_path):
    """Build the UNet autoencoder, load its weights and return it in eval mode.

    Args:
        auto_model_path: Path to a state dict saved for ``Autoencoder``.

    Returns:
        The ``Autoencoder`` instance on ``device``, switched to eval mode.
    """
    unet = UNetAuto(in_channels=1, out_channels=2).to(device)
    model = Autoencoder(unet).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # load_state_dict copies into the existing parameters, so the model stays
    # on `device` without a second .to() call.
    model.load_state_dict(torch.load(auto_model_path, map_location=device))
    model.eval()
    return model


# Generator-architecture builders selectable through load_model(model_type=...).
_GENERATOR_BUILDERS = {
    'resnet': build_res_unet,
    'mobilenet': build_mobilenet_unet,
}


def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    """Build a ``MainModel`` around a U-Net generator and load its weights.

    Args:
        generator_model_path: State dict for the generator network alone.
        colorization_model_path: State dict for the full ``MainModel``.
        model_type: ``'resnet'`` or ``'mobilenet'``; selects the generator
            builder.

    Returns:
        The ``MainModel`` on ``device`` in eval mode.

    Raises:
        ValueError: If ``model_type`` is not a known generator type.
    """
    try:
        build_generator = _GENERATOR_BUILDERS[model_type]
    except KeyError:
        # Previously an unknown type fell through and crashed later with an
        # UnboundLocalError on net_G; fail early with a clear message instead.
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(_GENERATOR_BUILDERS)}"
        ) from None

    net_G = build_generator(n_input=1, n_output=2, size=256)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


resnet_model = load_model(
    "weight/pascal_res18-unet.pt",
    "weight/pascal_final_model_weights.pt",
    model_type='resnet',
)

mobilenet_model = load_model(
    "weight/mobile-unet.pt",
    "weight/mobile_pascal_final_model_weights.pt",
    model_type='mobilenet',
)

autoencoder_model = load_autoencoder_model("weight/autoencoder.pt")


# Transformations
def preprocess_image(image):
    """Resize a PIL image to 256x256 and return its first channel in [-1, 1].

    Args:
        image: A PIL image (the caller passes a single-channel 'L' image).

    Returns:
        A float tensor of shape (1, 256, 256) on the CPU.
    """
    image = image.resize((256, 256))
    # ToTensor scales pixel values to [0, 1]; keep only the first channel and
    # map it to [-1, 1].
    return transforms.ToTensor()(image)[:1] * 2. - 1.


def postprocess_image(grayscale, prediction):
    """Combine the input channel and a predicted ab map into one RGB image.

    Args:
        grayscale: The (1, H, W) tensor produced by ``preprocess_image``.
        prediction: Network output for that image, with a leading batch
            dimension of 1.

    Returns:
        The first (only) image returned by ``utils.lab_to_rgb``.
    """
    # Move both inputs to the CPU so lab_to_rgb is never handed tensors that
    # live on different devices (grayscale may still be on the GPU).
    return lab_to_rgb(grayscale.unsqueeze(0).cpu(), prediction.cpu())[0]


# Prediction function with output control
def colorize_image(input_image, mode):
    """Colorize an uploaded image with the model(s) selected by ``mode``.

    Args:
        input_image: The uploaded image as a numpy array, or ``None`` when
            nothing has been uploaded.
        mode: One of ``OUTPUT_MODES``.

    Returns:
        A 3-tuple ``(resnet, mobilenet, autoencoder)``. Each slot holds a
        colorized image, or ``None`` when the mode does not request it.

    Raises:
        gr.Error: If no image was uploaded (shown to the user by Gradio).
        ValueError: If ``mode`` is not a known output mode.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        wanted = _MODE_OUTPUTS[mode]
    except KeyError:
        raise ValueError(f"Unknown output mode: {mode!r}") from None

    grayscale_image = Image.fromarray(input_image).convert('L')
    grayscale = preprocess_image(grayscale_image).to(device)
    batch = grayscale.unsqueeze(0)

    networks = (resnet_model.net_G, mobilenet_model.net_G, autoencoder_model)
    results = []
    with torch.no_grad():
        # Only run the networks whose output is actually displayed.
        for show, network in zip(wanted, networks):
            results.append(postprocess_image(grayscale, network(batch)) if show else None)
    return tuple(results)


# Gradio Interface
def gradio_interface():
    """Build and return the Gradio Blocks UI for the colorization demo."""
    initial_mode = MODE_RESNET
    initial_visibility = _MODE_OUTPUTS[initial_mode]

    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=OUTPUT_MODES,
            value=initial_mode,
            label="Output Mode",
        )
        submit_button = gr.Button("Submit")

        # Output components, placed in a single row. Initial visibility follows
        # the default mode; otherwise the first result lands in a hidden image.
        with gr.Row():
            resnet_output = gr.Image(
                label="Colorized Image (ResNet18)", visible=initial_visibility[0]
            )
            mobilenet_output = gr.Image(
                label="Colorized Image (MobileNet)", visible=initial_visibility[1]
            )
            autoencoder_output = gr.Image(
                label="Colorized Image (Autoencoder)", visible=initial_visibility[2]
            )
        outputs = [resnet_output, mobilenet_output, autoencoder_output]

        def update_visibility(mode):
            """Show exactly the output images that ``mode`` produces."""
            return tuple(gr.update(visible=show) for show in _MODE_OUTPUTS[mode])

        # Dynamic event listener for output mode changes
        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=outputs,
        )

        # Submit logic
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=outputs,
        )

    return demo


# Launch
if __name__ == "__main__":
    gradio_interface().launch()