File size: 5,029 Bytes
cf2db44
68fafaa
cf2db44
 
 
 
 
68fafaa
 
cf2db44
 
 
68fafaa
01c3f1c
68fafaa
 
 
 
 
 
cf2db44
68fafaa
 
 
 
 
cf2db44
 
 
 
 
 
 
 
 
 
68fafaa
 
cf2db44
 
68fafaa
 
 
 
 
 
01c3f1c
cf2db44
 
 
 
68fafaa
cf2db44
 
11f1268
 
 
 
 
cf2db44
68fafaa
 
 
11f1268
68fafaa
cf2db44
 
 
68fafaa
01c3f1c
cf2db44
11f1268
 
 
01c3f1c
cf2db44
68fafaa
 
 
 
01c3f1c
 
68fafaa
01c3f1c
68fafaa
cf2db44
 
68fafaa
 
 
 
 
01c3f1c
68fafaa
 
 
 
 
 
 
 
 
 
01c3f1c
68fafaa
 
 
 
 
 
 
01c3f1c
68fafaa
 
 
 
 
 
 
 
01c3f1c
68fafaa
 
 
 
 
 
01c3f1c
68fafaa
 
 
 
 
 
 
11f1268
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
import gradio as gr

from models import MainModel, UNetAuto, Autoencoder  
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet  # Utility to convert LAB to RGB

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model-loading helpers
def load_unet_model(auto_model_path):
    """Construct the U-Net autoencoder and restore its weights for inference.

    Args:
        auto_model_path: Path to the state dict that is loaded into an
            ``Autoencoder`` wrapping ``UNetAuto(in_channels=1, out_channels=2)``.

    Returns:
        The ``Autoencoder`` on ``device``, switched to eval mode.
    """
    backbone = UNetAuto(in_channels=1, out_channels=2).to(device)
    autoencoder = Autoencoder(backbone).to(device)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(auto_model_path, map_location=device)
    autoencoder.load_state_dict(state_dict)
    # nn.Module.to() and .eval() both return the module itself.
    return autoencoder.to(device).eval()

def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
    """Build a generator backbone, load both checkpoints and return the model for inference.

    Args:
        generator_model_path: Path to the state dict loaded into the generator
            network (``net_G``) on its own.
        colorization_model_path: Path to the state dict loaded into the full
            ``MainModel`` that wraps the generator.
        model_type: Which generator to build: ``'resnet'`` (``build_res_unet``)
            or ``'mobilenet'`` (``build_mobilenet_unet``).

    Returns:
        The ``MainModel`` moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not ``'resnet'`` or ``'mobilenet'``.
    """
    if model_type == 'resnet':
        net_G = build_res_unet(n_input=1, n_output=2, size=256)
    elif model_type == 'mobilenet':
        net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on net_G below.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'resnet' or 'mobilenet'"
        )

    # NOTE: torch.load unpickles the files; only point these at trusted checkpoints.
    net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
    model = MainModel(net_G=net_G)
    model.load_state_dict(torch.load(colorization_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

# Instantiate every model once at import time so each request reuses them.
resnet_model = load_model(
    generator_model_path="weight/pascal_res18-unet.pt",
    colorization_model_path="weight/pascal_final_model_weights.pt",
    model_type="resnet",
)

mobilenet_model = load_model(
    generator_model_path="weight/mobile-unet.pt",
    colorization_model_path="weight/mobile_pascal_final_model_weights.pt",
    model_type="mobilenet",
)

unet_model = load_unet_model(auto_model_path="weight/autoencoder.pt")

# Transformations
def preprocess_image(image):
    """Resize a PIL image to 256x256 and return its first channel scaled to [-1, 1].

    Args:
        image: PIL image; the caller passes a single-channel ('L') image.

    Returns:
        A float tensor of shape (1, 256, 256). ``ToTensor`` maps 8-bit pixels to
        [0, 1], and the affine ``x * 2 - 1`` maps that to [-1, 1].
    """
    to_tensor = transforms.ToTensor()
    resized = image.resize((256, 256))
    first_channel = to_tensor(resized)[:1]
    return first_channel * 2. - 1.

def postprocess_image(grayscale, prediction, original_size):
    """Combine the lightness tensor with predicted colour channels into an RGB image.

    Args:
        grayscale: Preprocessed lightness tensor from ``preprocess_image``,
            presumably shape (1, H, W) in [-1, 1]. It may live on ``device``.
        prediction: Batched model output for the colour channels, presumably
            (1, 2, H, W). It may live on ``device``.
        original_size: Target size in PIL ``(width, height)`` order.

    Returns:
        A PIL RGB image resized to ``original_size``.
    """
    # Move both inputs to the CPU. lab_to_rgb receives them together, and the
    # lightness tensor may still be on the GPU while the prediction is not.
    lightness = grayscale.unsqueeze(0).cpu()
    ab = prediction.cpu()
    colorized = lab_to_rgb(lightness, ab)[0]
    # Clip before the uint8 cast so slightly out-of-range floats don't wrap around.
    pixels = (np.clip(colorized, 0.0, 1.0) * 255).astype("uint8")
    return Image.fromarray(pixels).resize(original_size)

# Prediction function with output control
def colorize_image(input_image, mode):
    """Colorize an uploaded image with the model(s) selected by ``mode``.

    Only the models whose output is actually shown are run, so single-model
    modes cost one forward pass instead of three.

    Args:
        input_image: Image array as delivered by ``gr.Image(type="numpy")``,
            or ``None`` when nothing has been uploaded.
        mode: One of ``"ResNet"``, ``"MobileNet"``, ``"Unet"`` or ``"Comparison"``.

    Returns:
        A 3-tuple ``(resnet, mobilenet, unet)`` of PIL images. Slots not
        requested by ``mode`` are ``None``. All three are ``None`` when
        ``input_image`` is ``None``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if input_image is None:
        # Submit pressed without an upload: clear the outputs instead of crashing.
        return None, None, None

    # Which output slots (ResNet, MobileNet, Unet) each mode fills.
    slot_masks = {
        "ResNet": (True, False, False),
        "MobileNet": (False, True, False),
        "Unet": (False, False, True),
        "Comparison": (True, True, True),
    }
    if mode not in slot_masks:
        raise ValueError(f"Unknown output mode: {mode!r}")

    grayscale_image = Image.fromarray(input_image).convert('L')
    original_size = grayscale_image.size  # PIL (width, height), used to resize back
    grayscale = preprocess_image(grayscale_image).to(device)
    batch = grayscale.unsqueeze(0)

    # Same order as the returned tuple / the Gradio output components.
    generators = (resnet_model.net_G, mobilenet_model.net_G, unet_model)

    outputs = []
    with torch.no_grad():
        for wanted, generator in zip(slot_masks[mode], generators):
            if not wanted:
                outputs.append(None)
                continue
            prediction = generator(batch)
            outputs.append(postprocess_image(grayscale, prediction, original_size))
    return tuple(outputs)


# Gradio Interface
def gradio_interface():
    """Build the Gradio Blocks UI for the colorization demo.

    The layout is an image upload, a radio to choose which model output(s) to
    show, a Submit button, and one output image per model. Changing the radio
    toggles which outputs are visible. Submit runs ``colorize_image``.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    default_mode = "ResNet"
    # Output slots, in the same order as colorize_image's returned tuple.
    slot_modes = ("ResNet", "MobileNet", "Unet")

    def is_visible(slot_mode, mode):
        # A slot is shown when its own mode is selected, or in comparison mode.
        return mode in (slot_mode, "Comparison")

    with gr.Blocks() as demo:
        # Input components
        input_image = gr.Image(type="numpy", label="Upload an Image")
        output_modes = gr.Radio(
            choices=["ResNet", "MobileNet", "Unet", "Comparison"],
            value=default_mode,
            label="Output Mode",
        )

        submit_button = gr.Button("Submit")

        # Output components. Initial visibility must match the radio's default
        # value: .change does not fire for the initial value, so starting every
        # output hidden would hide the default (ResNet) result.
        with gr.Row():  # Place output images in a single row
            resnet_output = gr.Image(
                label="Colorized Image (ResNet18)",
                visible=is_visible("ResNet", default_mode),
            )
            mobilenet_output = gr.Image(
                label="Colorized Image (MobileNet)",
                visible=is_visible("MobileNet", default_mode),
            )
            unet_output = gr.Image(
                label="Colorized Image (Unet)",
                visible=is_visible("Unet", default_mode),
            )

        def update_visibility(mode):
            """Return visibility updates for the three output images for ``mode``."""
            return tuple(gr.update(visible=is_visible(slot, mode)) for slot in slot_modes)

        # Dynamic event listener for output mode changes
        output_modes.change(
            fn=update_visibility,
            inputs=[output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output],
        )

        # Submit logic
        submit_button.click(
            fn=colorize_image,
            inputs=[input_image, output_modes],
            outputs=[resnet_output, mobilenet_output, unet_output],
        )

    return demo


# Launch
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    demo = gradio_interface()
    demo.launch()