Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from PIL import Image | |
from torchvision import transforms | |
from matplotlib import pyplot as plt | |
import gradio as gr | |
from models import MainModel, UNetAuto, Autoencoder | |
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet # Utility to convert LAB to RGB | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Model-loading functions
def load_unet_model(auto_model_path):
    """Build the UNet autoencoder and load its weights from *auto_model_path*.

    A ``UNetAuto`` (1 input channel, 2 output channels) is wrapped in an
    ``Autoencoder``; the state dict at *auto_model_path* is loaded onto the
    module-level ``device``. The model is returned in evaluation mode.
    """
    autoencoder = Autoencoder(
        UNetAuto(in_channels=1, out_channels=2).to(device)
    ).to(device)
    checkpoint = torch.load(auto_model_path, map_location=device)
    autoencoder.load_state_dict(checkpoint)
    autoencoder.to(device)
    autoencoder.eval()
    return autoencoder
def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    """Build a ``MainModel`` colorizer around a U-Net generator and load its weights.

    Args:
        generator_model_path: path to a state dict for the generator (``net_G``).
        colorization_model_path: path to a state dict for the whole ``MainModel``.
        model_type: generator backbone, either ``'resnet'`` or ``'mobilenet'``.

    Returns:
        The ``MainModel`` moved to the module-level ``device`` and set to eval mode.

    Raises:
        ValueError: if *model_type* is not a supported backbone.
    """
    if model_type == 'resnet':
        net_G = build_res_unet(n_input=1, n_output=2, size=256)
    elif model_type == 'mobilenet':
        net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on net_G below.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'resnet' or 'mobilenet'"
        )
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    # NOTE(review): if the MainModel checkpoint also stores net_G's parameters,
    # this load overrides the generator weights loaded above — confirm intent.
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
# Load every model once at import time; colorize_image reuses these globals per request.
resnet_model = load_model(
    generator_model_path="weight/pascal_res18-unet.pt",
    colorization_model_path="weight/pascal_final_model_weights.pt",
    model_type="resnet",
)
mobilenet_model = load_model(
    generator_model_path="weight/mobile-unet.pt",
    colorization_model_path="weight/mobile_pascal_final_model_weights.pt",
    model_type="mobilenet",
)
unet_model = load_unet_model(auto_model_path="weight/autoencoder.pt")
# Transformations | |
def preprocess_image(image):
    """Resize a PIL image to 256x256 and return its first channel rescaled.

    ``ToTensor`` yields a (C, H, W) float tensor (in [0, 1] for 8-bit images);
    only channel 0 is kept and mapped linearly to [-1, 1], so the result has
    shape (1, 256, 256).
    """
    resized = image.resize((256, 256))
    to_tensor = transforms.ToTensor()
    first_channel = to_tensor(resized)[:1]
    return first_channel * 2. - 1.
def postprocess_image(grayscale, prediction, original_size):
    """Combine an L channel with predicted colour channels into a PIL RGB image.

    Args:
        grayscale: L-channel tensor for one image, as returned by
            ``preprocess_image`` (no batch dimension).
        prediction: model output for the same image, with a leading batch
            dimension of 1 (presumably the two ab channels — confirm).
        original_size: ``(width, height)`` to resize the result back to.

    Returns:
        A ``PIL.Image`` of size *original_size*.
    """
    # Put both tensors on the CPU: ``grayscale`` may still live on the GPU while
    # ``prediction`` is moved, and lab_to_rgb presumably combines the two.
    lab_l = grayscale.unsqueeze(0).cpu()
    lab_ab = prediction.cpu()
    colorized = lab_to_rgb(lab_l, lab_ab)[0]
    # The * 255 scaling implies values in [0, 1]; clip so out-of-gamut values
    # saturate instead of wrapping around in the uint8 cast.
    colorized = np.clip(colorized, 0.0, 1.0)
    colorized_image = Image.fromarray((colorized * 255).astype("uint8"))
    return colorized_image.resize(original_size)
# Prediction function with output control | |
def colorize_image(input_image, mode):
    """Colorize *input_image* with the model(s) selected by *mode*.

    Args:
        input_image: image array from the Gradio ``Image(type="numpy")`` input,
            or ``None`` when the user submits without uploading anything.
        mode: ``"ResNet"``, ``"MobileNet"``, ``"Unet"``, or ``"Comparison"``
            (run all three).

    Returns:
        A ``(resnet, mobilenet, unet)`` tuple of PIL images, each resized to the
        input's size. Entries for models not selected by *mode* are ``None``.

    Raises:
        ValueError: if *mode* is not a recognised output mode.
    """
    model_names = ("ResNet", "MobileNet", "Unet")
    if mode == "Comparison":
        selected = set(model_names)
    elif mode in model_names:
        selected = {mode}
    else:
        raise ValueError(f"Unknown output mode: {mode!r}")

    if input_image is None:
        # Nothing uploaded: leave every output empty instead of crashing.
        return None, None, None

    grayscale_image = Image.fromarray(input_image).convert('L')
    original_size = grayscale_image.size  # (width, height); restored after inference
    grayscale = preprocess_image(grayscale_image).to(device)
    batch = grayscale.unsqueeze(0)

    predictors = {
        "ResNet": resnet_model.net_G,
        "MobileNet": mobilenet_model.net_G,
        "Unet": unet_model,
    }
    outputs = []
    with torch.no_grad():
        # Only run the networks whose output is shown, so single-model modes
        # don't pay for all three forward passes.
        for name in model_names:
            if name in selected:
                prediction = predictors[name](batch)
                outputs.append(postprocess_image(grayscale, prediction, original_size))
            else:
                outputs.append(None)
    return tuple(outputs)
# Gradio Interface | |
def gradio_interface():
    """Build the Gradio Blocks UI for the colorization demo.

    The UI has an image upload, an output-mode selector, a Submit button, and
    one output image per model. Visibility of the outputs follows the selected
    mode, both initially and whenever the mode changes.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # Visibility of the (ResNet, MobileNet, Unet) outputs for each mode. This is
    # the single source of truth for the radio choices, the initial state, and
    # the change handler.
    visibility_by_mode = {
        "ResNet": (True, False, False),
        "MobileNet": (False, True, False),
        "Unet": (False, False, True),
        "Comparison": (True, True, True),
    }
    default_mode = "ResNet"
    # `.change` does not fire on page load, so the initial visibility must
    # already match the default mode; otherwise nothing shows until the user
    # switches modes.
    show_resnet, show_mobilenet, show_unet = visibility_by_mode[default_mode]

    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=list(visibility_by_mode),
            value=default_mode,
            label="Output Mode",
        )
        submit_button = gr.Button("Submit")

        # Output components, placed side by side in a single row
        with gr.Row():
            resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=show_resnet)
            mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=show_mobilenet)
            unet_output = gr.Image(label="Colorized Image (Unet)", visible=show_unet)

        def update_visibility(mode):
            """Return visibility updates for the three outputs for *mode*."""
            return tuple(gr.update(visible=flag) for flag in visibility_by_mode[mode])

        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output],
        )
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output],
        )
    return demo
# Launch | |
if __name__ == "__main__":
    # Build the UI, then start the Gradio server.
    app = gradio_interface()
    app.launch()