import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr
from models import MainModel, UNetAuto, Autoencoder
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet  # LAB->RGB conversion and U-Net generator builders
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model loading helpers
def load_unet_model(auto_model_path):
    unet = UNetAuto(in_channels=1, out_channels=2).to(device)
    model = Autoencoder(unet).to(device)
    model.load_state_dict(torch.load(auto_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
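
# Assumption: UNetAuto/Autoencoder (imported from models) map the single L channel
# to a 2-channel ab prediction, matching the in_channels=1 / out_channels=2 above.
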
def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    if model_type == 'resnet':
        net_G = build_res_unet(n_input=1, n_output=2, size=256)
    elif model_type == 'mobilenet':
        net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

resnet_model = load_model(
    "weight/pascal_res18-unet.pt",
    "weight/pascal_final_model_weights.pt",
    model_type='resnet'
)
mobilenet_model = load_model(
    "weight/mobile-unet.pt",
    "weight/mobile_pascal_final_model_weights.pt",
    model_type='mobilenet'
)
unet_model = load_unet_model("weight/autoencoder.pt")
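
# NOTE: the three weight files above are assumed to sit in a local `weight/` directory
# next to this script; the GAN generators are built for 256x256 inputs (size=256),
# and preprocess_image below resizes every image to that resolution before inference.
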
# Transformations
def preprocess_image(image):
    image = image.resize((256, 256))
    # Keep only the first (L) channel and rescale it from [0, 1] to [-1, 1]
    image = transforms.ToTensor()(image)[:1] * 2. - 1.
    return image

def postprocess_image(grayscale, prediction, original_size):
    # Convert Lab back to RGB and resize to the original image size
    colorized_image = lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
    colorized_image = Image.fromarray((colorized_image * 255).astype("uint8"))
    return colorized_image.resize(original_size)
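
# Assumption: lab_to_rgb (from utils) expects the normalized L channel produced by
# preprocess_image plus the predicted ab channels, and returns float RGB in [0, 1];
# the multiplication by 255 above relies on that convention.
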
# Prediction function with output control
def colorize_image(input_image, mode):
    grayscale_image = Image.fromarray(input_image).convert('L')
    original_size = grayscale_image.size  # Store original size
    grayscale = preprocess_image(grayscale_image).to(device)

    with torch.no_grad():
        resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
        mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
        unet_output = unet_model(grayscale.unsqueeze(0))

    # Resize outputs to match the original size
    resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
    mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
    unet_colorized = postprocess_image(grayscale, unet_output, original_size)

    if mode == "ResNet":
        return resnet_colorized, None, None
    elif mode == "MobileNet":
        return None, mobilenet_colorized, None
    elif mode == "Unet":
        return None, None, unet_colorized
    elif mode == "Comparison":
        return resnet_colorized, mobilenet_colorized, unet_colorized
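
# Standalone usage sketch (hypothetical file names), bypassing the Gradio UI:
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   resnet_only, _, _ = colorize_image(img, "ResNet")
#   resnet_only.save("example_colorized.jpg")
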
# Gradio Interface
def gradio_interface():
    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=["ResNet", "MobileNet", "Unet", "Comparison"],
            value="ResNet",
            label="Output Mode"
        )
        submit_button = gr.Button("Submit")

        # Output components
        with gr.Row():  # Place output images in a single row
            # ResNet is the default mode, so its output starts visible
            resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=True)
            mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
            unet_output = gr.Image(label="Colorized Image (Unet)", visible=False)

        # Output mode logic
        def update_visibility(mode):
            if mode == "ResNet":
                return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
            elif mode == "MobileNet":
                return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
            elif mode == "Unet":
                return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
            elif mode == "Comparison":
                return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

        # Dynamic event listener for output mode changes
        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output]
        )

        # Submit logic
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output]
        )

    return demo
# Launch
if __name__ == "__main__":
    gradio_interface().launch()
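    # Optional: launch(share=True) would also create a temporary public URL when
    # running locally; it is not needed when hosted on Hugging Face Spaces.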