Colorization / app.py
ChiKyi's picture
update models
68fafaa
raw
history blame
4.76 kB
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr
from models import MainModel, UNetAuto, Autoencoder
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet # Utility to convert LAB to RGB
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Model loaders
def load_autoencoder_model(auto_model_path):
    """Build the autoencoder colorizer and load its saved weights.

    Args:
        auto_model_path: Path to a state dict for the ``Autoencoder`` model
            (it is passed to ``load_state_dict``).

    Returns:
        The ``Autoencoder`` (wrapping a 1-in / 2-out ``UNetAuto``) moved to
        ``device`` and switched to eval mode.
    """
    backbone = UNetAuto(in_channels=1, out_channels=2).to(device)
    autoencoder = Autoencoder(backbone).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(auto_model_path, map_location=device)
    autoencoder.load_state_dict(state_dict)
    autoencoder.to(device)
    autoencoder.eval()
    return autoencoder
def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    """Build a ``MainModel`` colorizer around a generator ``net_G`` and load weights.

    The generator is built by ``build_res_unet`` or ``build_mobilenet_unet``
    (1 input channel, 2 output channels, size 256) and initialised from
    ``generator_model_path``. The whole ``MainModel`` state dict is then loaded
    from ``colorization_model_path``.

    Args:
        generator_model_path: Path to a state dict for the generator ``net_G``.
        colorization_model_path: Path to a state dict for the full ``MainModel``.
        model_type: Which generator builder to use: ``'resnet'`` or ``'mobilenet'``.

    Returns:
        The ``MainModel`` moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not ``'resnet'`` or ``'mobilenet'``.
    """
    # Pick the builder up front so an unsupported type fails with a clear
    # message instead of an UnboundLocalError on ``net_G`` further down.
    if model_type == 'resnet':
        build_generator = build_res_unet
    elif model_type == 'mobilenet':
        build_generator = build_mobilenet_unet
    else:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'resnet' or 'mobilenet'"
        )

    net_G = build_generator(n_input=1, n_output=2, size=256)
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    # NOTE(review): if the MainModel checkpoint also contains net_G's weights,
    # the load above is superseded by this one -- confirm whether it is needed.
    model = MainModel(net_G=net_G)
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
# Load every model once at import time; colorize_image reuses these globals
# for each request.
resnet_model = load_model(
    generator_model_path="weight/pascal_res18-unet.pt",
    colorization_model_path="weight/pascal_final_model_weights.pt",
    model_type='resnet',
)
mobilenet_model = load_model(
    generator_model_path="weight/mobile-unet.pt",
    colorization_model_path="weight/mobile_pascal_final_model_weights.pt",
    model_type='mobilenet',
)
autoencoder_model = load_autoencoder_model(auto_model_path="weight/autoencoder.pt")
# Transformations
def preprocess_image(image):
    """Turn a PIL image into the single-channel input tensor fed to the models.

    The image is resized to 256x256 and converted with torchvision's
    ``ToTensor``. Only the first channel is kept, and it is rescaled with
    ``x * 2 - 1``; for 8-bit images that maps [0, 1] onto [-1, 1].

    Args:
        image: A PIL image (the caller in this file passes an ``'L'`` image).

    Returns:
        A float tensor of shape ``(1, 256, 256)``.
    """
    to_tensor = transforms.ToTensor()
    channels = to_tensor(image.resize((256, 256)))
    lightness = channels[:1]
    return lightness * 2. - 1.
def postprocess_image(grayscale, prediction):
    """Combine the input lightness channel with a model prediction via ``lab_to_rgb``.

    Args:
        grayscale: Preprocessed input tensor without a batch dimension (as
            returned by ``preprocess_image``); it may live on ``device``.
        prediction: Batched model output for that image (batch size 1 in this file).

    Returns:
        The first element of ``lab_to_rgb``'s result, i.e. the single image.
    """
    # Move BOTH tensors to the CPU. ``grayscale`` may still be on a CUDA device
    # while ``prediction`` is moved, so on a GPU machine ``lab_to_rgb`` would
    # otherwise receive tensors on different devices.
    lightness = grayscale.unsqueeze(0).cpu()
    return lab_to_rgb(lightness, prediction.cpu())[0]
# Prediction function with output control
def colorize_image(input_image, mode):
    """Colorize an uploaded image with the model(s) selected by ``mode``.

    Args:
        input_image: Image array from the Gradio ``Image(type="numpy")`` input,
            or ``None`` when nothing has been uploaded.
        mode: ``"ResNet"``, ``"MobileNet"``, ``"Autoencoder"`` or ``"Comparison"``.

    Returns:
        A 3-tuple ``(resnet, mobilenet, autoencoder)`` of colorized images.
        Entries for models not selected by ``mode`` are ``None``. A missing
        input or an unrecognised mode yields ``(None, None, None)``.
    """
    run_resnet = mode in ("ResNet", "Comparison")
    run_mobilenet = mode in ("MobileNet", "Comparison")
    run_autoencoder = mode in ("Autoencoder", "Comparison")

    # Submit without an upload passes None. Gradio also expects one value per
    # output component, so always return a 3-tuple instead of crashing in
    # Image.fromarray or implicitly returning a single None.
    if input_image is None or not (run_resnet or run_mobilenet or run_autoencoder):
        return None, None, None

    grayscale_image = Image.fromarray(input_image).convert('L')
    grayscale = preprocess_image(grayscale_image).to(device)
    batch = grayscale.unsqueeze(0)

    def _colorize(network):
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            prediction = network(batch)
        return postprocess_image(grayscale, prediction)

    # Run only the networks whose output will be displayed, avoiding
    # unnecessary forward passes in the single-model modes.
    resnet_colorized = _colorize(resnet_model.net_G) if run_resnet else None
    mobilenet_colorized = _colorize(mobilenet_model.net_G) if run_mobilenet else None
    autoencoder_colorized = _colorize(autoencoder_model) if run_autoencoder else None
    return resnet_colorized, mobilenet_colorized, autoencoder_colorized
# Gradio Interface
def gradio_interface():
    """Build the Gradio Blocks UI for the colorization demo.

    The UI has an image upload, an output-mode radio button, a Submit button and
    one output image per model. Changing the mode toggles which output images
    are visible; Submit calls ``colorize_image``.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    default_mode = "ResNet"
    # Visibility of the (ResNet, MobileNet, Autoencoder) outputs for each mode.
    visibility_by_mode = {
        "ResNet": (True, False, False),
        "MobileNet": (False, True, False),
        "Autoencoder": (False, False, True),
        "Comparison": (True, True, True),
    }
    # Show the default mode's panels from the start. Otherwise the ResNet
    # result would be written into a hidden component until the user changed
    # the radio selection.
    show_resnet, show_mobilenet, show_autoencoder = visibility_by_mode[default_mode]

    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=list(visibility_by_mode),
            value=default_mode,
            label="Output Mode",
        )
        submit_button = gr.Button("Submit")

        # Output components, placed side by side in a single row
        with gr.Row():
            resnet_output = gr.Image(
                label="Colorized Image (ResNet18)", visible=show_resnet
            )
            mobilenet_output = gr.Image(
                label="Colorized Image (MobileNet)", visible=show_mobilenet
            )
            autoencoder_output = gr.Image(
                label="Colorized Image (Autoencoder)", visible=show_autoencoder
            )

        def update_visibility(mode):
            """Return one ``gr.update`` per output image for the selected mode."""
            flags = visibility_by_mode.get(mode, (False, False, False))
            return tuple(gr.update(visible=flag) for flag in flags)

        # Toggle the output panels whenever the output mode changes
        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=[resnet_output, mobilenet_output, autoencoder_output],
        )
        # Run the colorization on Submit
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=[resnet_output, mobilenet_output, autoencoder_output],
        )
    return demo
# Launch: build the UI and start the server only when run as a script.
if __name__ == "__main__":
    app = gradio_interface()
    app.launch()