# Colorization / app.py
# Last commit 01c3f1c ("update") by ChiKyi; file size 5.03 kB.
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr
from models import MainModel, UNetAuto, Autoencoder
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet # Utility to convert LAB to RGB
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Functions that load the models
def load_unet_model(auto_model_path):
    """Build the UNet autoencoder and load its trained weights.

    Args:
        auto_model_path: path to a state_dict checkpoint for the Autoencoder.

    Returns:
        The Autoencoder (wrapping a UNetAuto with 1 input channel and
        2 output channels) on `device`, in eval mode.
    """
    unet = UNetAuto(in_channels=1, out_channels=2).to(device)
    model = Autoencoder(unet).to(device)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load on this host. load_state_dict copies into the model's existing,
    # already on-device parameters, so no second .to(device) is needed.
    model.load_state_dict(torch.load(auto_model_path, map_location=device))
    model.eval()
    return model
def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    """Build a colorization MainModel with the chosen generator and load its weights.

    Args:
        generator_model_path: checkpoint loaded into the generator (net_G) first.
        colorization_model_path: checkpoint then loaded into the whole MainModel.
        model_type: 'resnet' (uses build_res_unet) or 'mobilenet'
            (uses build_mobilenet_unet).

    Returns:
        The MainModel moved to `device` and switched to eval mode.

    Raises:
        ValueError: if `model_type` is not one of the supported backbones.
    """
    if model_type == 'resnet':
        net_G = build_res_unet(n_input=1, n_output=2, size=256)
    elif model_type == 'mobilenet':
        net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
    else:
        # Without this branch an unknown type fell through and crashed with an
        # opaque UnboundLocalError on net_G below.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'resnet' or 'mobilenet'"
        )
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    # NOTE(review): if the MainModel checkpoint also holds net_G's weights, they
    # overwrite the ones loaded above — confirm whether the first load is needed.
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
# Load every model once at import time; colorize_image reuses these globals
# for every request.
resnet_model = load_model(
    generator_model_path="weight/pascal_res18-unet.pt",
    colorization_model_path="weight/pascal_final_model_weights.pt",
    model_type="resnet",
)
mobilenet_model = load_model(
    generator_model_path="weight/mobile-unet.pt",
    colorization_model_path="weight/mobile_pascal_final_model_weights.pt",
    model_type="mobilenet",
)
unet_model = load_unet_model(auto_model_path="weight/autoencoder.pt")
# Transformations
def preprocess_image(image):
    """Turn a PIL image into the normalised single-channel model input.

    The image is resized to 256x256 and converted with torchvision's ToTensor,
    which scales 8-bit images to [0, 1]. Only the first channel is kept, as a
    (1, 256, 256) tensor, and it is mapped linearly from [0, 1] to [-1, 1].
    """
    resized = image.resize((256, 256))
    as_tensor = transforms.ToTensor()(resized)
    first_channel = as_tensor[:1]  # slice keeps the channel dim: (1, H, W)
    return first_channel * 2. - 1.
def postprocess_image(grayscale, prediction, original_size):
    """Merge the input L channel with predicted ab channels into an RGB PIL image.

    Args:
        grayscale: normalised L-channel tensor from preprocess_image, shape
            (1, 256, 256); it may still live on `device`.
        prediction: the model's ab prediction for a batch of one image.
        original_size: (width, height) in PIL's convention; the result is
            resized back to it.

    Returns:
        PIL.Image.Image: the colorized image at `original_size`.
    """
    # Bring BOTH tensors to the CPU. Previously only `prediction` was moved, so on
    # a CUDA host lab_to_rgb received one CUDA and one CPU tensor, which fails if
    # it combines them (e.g. torch.cat). NOTE(review): confirm against utils.lab_to_rgb.
    lightness = grayscale.unsqueeze(0).cpu()
    ab_channels = prediction.cpu()
    rgb = lab_to_rgb(lightness, ab_channels)[0]
    # Clip before the uint8 cast so slightly out-of-range floats saturate
    # instead of wrapping around.
    rgb = np.clip(rgb, 0.0, 1.0)
    colorized_image = Image.fromarray((rgb * 255).astype("uint8"))
    return colorized_image.resize(original_size)
# Prediction function with output control
def colorize_image(input_image, mode):
    """Colorize an uploaded image with the model(s) selected by `mode`.

    Args:
        input_image: the uploaded image as a numpy array (from gr.Image with
            type="numpy"), or None when nothing has been uploaded.
        mode: one of "ResNet", "MobileNet", "Unet" or "Comparison".

    Returns:
        A 3-tuple (resnet, mobilenet, unet) of PIL images resized to the
        input's size. Outputs not shown in `mode` are None, and all three are
        None when there is no input image.

    Raises:
        ValueError: if `mode` is not one of the supported modes.
    """
    # Which of the (ResNet, MobileNet, Unet) outputs each mode produces.
    wanted = {
        "ResNet": (True, False, False),
        "MobileNet": (False, True, False),
        "Unet": (False, False, True),
        "Comparison": (True, True, True),
    }.get(mode)
    if wanted is None:
        # Previously an unknown mode fell off the end and returned a bare None,
        # which does not match the three Gradio outputs.
        raise ValueError(f"Unknown output mode: {mode!r}")
    if input_image is None:
        # Submit was clicked before an image was uploaded: clear all outputs.
        return None, None, None

    grayscale_image = Image.fromarray(input_image).convert('L')
    original_size = grayscale_image.size  # PIL size is (width, height)
    grayscale = preprocess_image(grayscale_image).to(device)
    batch = grayscale.unsqueeze(0)

    models = (resnet_model.net_G, mobilenet_model.net_G, unet_model)
    results = []
    with torch.no_grad():
        # Only run the models whose output is actually displayed; the original
        # ran all three on every request regardless of mode.
        for show, model in zip(wanted, models):
            if show:
                prediction = model(batch)
                results.append(postprocess_image(grayscale, prediction, original_size))
            else:
                results.append(None)
    return tuple(results)
# Gradio Interface
def gradio_interface():
    """Build the Gradio Blocks UI for the colorization demo.

    The UI has:
    - an image upload;
    - a radio button selecting which model output(s) to display;
    - a Submit button;
    - one output image per model, laid out in a row.

    Changing the radio shows or hides the output panels. Submit runs
    `colorize_image` and fills them.

    Returns:
        gr.Blocks: the assembled demo; the caller is responsible for launching it.
    """
    default_mode = "ResNet"
    # Visibility of the (ResNet, MobileNet, Unet) output panels for each mode.
    panel_visibility = {
        "ResNet": (True, False, False),
        "MobileNet": (False, True, False),
        "Unet": (False, False, True),
        "Comparison": (True, True, True),
    }

    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=list(panel_visibility),
            value=default_mode,
            label="Output Mode",
        )
        submit_button = gr.Button("Submit")

        # Output components, placed in a single row. Their initial visibility
        # must match the radio's default value: the .change listener does not
        # fire for the initial value, so with every panel hidden the default
        # "ResNet" result was never shown until the user switched modes.
        show_resnet, show_mobilenet, show_unet = panel_visibility[default_mode]
        with gr.Row():
            resnet_output = gr.Image(
                label="Colorized Image (ResNet18)", visible=show_resnet
            )
            mobilenet_output = gr.Image(
                label="Colorized Image (MobileNet)", visible=show_mobilenet
            )
            unet_output = gr.Image(
                label="Colorized Image (Unet)", visible=show_unet
            )

        def update_visibility(mode):
            """Return one gr.update per output panel, showing or hiding it for `mode`."""
            return tuple(gr.update(visible=flag) for flag in panel_visibility[mode])

        # Toggle panel visibility whenever the output mode changes.
        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output],
        )

        # Run colorization on Submit.
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output],
        )
    return demo
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch()