Spaces:
Sleeping
Sleeping
update models
Browse files- app.py +93 -50
- models.py +121 -1
- utils.py +6 -6
- weight/autoencoder.pt +3 -0
app.py
CHANGED
@@ -1,90 +1,133 @@
|
|
1 |
import torch
|
|
|
2 |
from PIL import Image
|
3 |
from torchvision import transforms
|
4 |
from matplotlib import pyplot as plt
|
5 |
import gradio as gr
|
6 |
|
7 |
-
from models import MainModel
|
8 |
-
from utils import lab_to_rgb, build_res_unet
|
9 |
|
10 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
def load_model(generator_model_path, colorization_model_path
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
# net_G = build_mobile_unet(n_input=1, n_output=2, size=256)
|
19 |
|
20 |
net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
|
21 |
-
|
22 |
-
# Create MainModel and load weights
|
23 |
model = MainModel(net_G=net_G)
|
24 |
model.load_state_dict(torch.load(colorization_model_path, map_location=device))
|
25 |
-
|
26 |
-
# Move model to device and set to eval mode
|
27 |
model.to(device)
|
28 |
model.eval()
|
29 |
-
|
30 |
return model
|
31 |
|
32 |
-
# Load pretrained models
|
33 |
resnet_model = load_model(
|
34 |
"weight/pascal_res18-unet.pt",
|
35 |
-
"weight/pascal_final_model_weights.pt"
|
36 |
-
|
37 |
)
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
44 |
|
45 |
# Transformations
|
46 |
def preprocess_image(image):
|
47 |
image = image.resize((256, 256))
|
48 |
-
image = transforms.ToTensor()(image)[:1] * 2. - 1.
|
49 |
return image
|
50 |
|
51 |
def postprocess_image(grayscale, prediction):
|
52 |
return lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
|
53 |
|
54 |
-
# Prediction function
|
55 |
-
def colorize_image(input_image):
|
56 |
-
|
57 |
-
|
58 |
-
grayscale = preprocess_image(input_image).to(device)
|
59 |
|
60 |
-
# Generate predictions
|
61 |
with torch.no_grad():
|
62 |
resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
|
63 |
-
|
|
|
64 |
|
65 |
-
# Post-process results
|
66 |
resnet_colorized = postprocess_image(grayscale, resnet_output)
|
67 |
-
|
|
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
74 |
|
75 |
# Gradio Interface
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
gr.
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
# Launch
|
89 |
-
if __name__ ==
|
90 |
-
|
|
|
1 |
import torch
|
2 |
+
import numpy as np
|
3 |
from PIL import Image
|
4 |
from torchvision import transforms
|
5 |
from matplotlib import pyplot as plt
|
6 |
import gradio as gr
|
7 |
|
8 |
+
from models import MainModel, UNetAuto, Autoencoder
|
9 |
+
from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet # Utility to convert LAB to RGB
|
10 |
|
11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
|
13 |
+
# Hàm load models
|
14 |
+
def load_autoencoder_model(auto_model_path):
|
15 |
+
unet = UNetAuto(in_channels=1, out_channels=2).to(device)
|
16 |
+
model = Autoencoder(unet).to(device)
|
17 |
+
model.load_state_dict(torch.load(auto_model_path, map_location=device))
|
18 |
+
model.to(device)
|
19 |
+
model.eval()
|
20 |
+
return model
|
21 |
|
22 |
+
def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
|
23 |
+
if model_type == 'resnet':
|
24 |
+
net_G = build_res_unet(n_input=1, n_output=2, size=256)
|
25 |
+
elif model_type == 'mobilenet':
|
26 |
+
net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
|
|
|
27 |
|
28 |
net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
|
|
|
|
|
29 |
model = MainModel(net_G=net_G)
|
30 |
model.load_state_dict(torch.load(colorization_model_path, map_location=device))
|
|
|
|
|
31 |
model.to(device)
|
32 |
model.eval()
|
|
|
33 |
return model
|
34 |
|
|
|
35 |
resnet_model = load_model(
|
36 |
"weight/pascal_res18-unet.pt",
|
37 |
+
"weight/pascal_final_model_weights.pt",
|
38 |
+
model_type='resnet'
|
39 |
)
|
40 |
|
41 |
+
mobilenet_model = load_model(
|
42 |
+
"weight/mobile-unet.pt",
|
43 |
+
"weight/mobile_pascal_final_model_weights.pt",
|
44 |
+
model_type='mobilenet'
|
45 |
+
)
|
46 |
+
|
47 |
+
autoencoder_model = load_autoencoder_model("weight/autoencoder.pt")
|
48 |
|
49 |
# Transformations
|
50 |
def preprocess_image(image):
|
51 |
image = image.resize((256, 256))
|
52 |
+
image = transforms.ToTensor()(image)[:1] * 2. - 1.
|
53 |
return image
|
54 |
|
55 |
def postprocess_image(grayscale, prediction):
|
56 |
return lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
|
57 |
|
58 |
+
# Prediction function with output control
|
59 |
+
def colorize_image(input_image, mode):
|
60 |
+
grayscale_image = Image.fromarray(input_image).convert('L')
|
61 |
+
grayscale = preprocess_image(grayscale_image).to(device)
|
|
|
62 |
|
|
|
63 |
with torch.no_grad():
|
64 |
resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
|
65 |
+
mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
|
66 |
+
autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
|
67 |
|
|
|
68 |
resnet_colorized = postprocess_image(grayscale, resnet_output)
|
69 |
+
mobilenet_colorized = postprocess_image(grayscale, mobilenet_output)
|
70 |
+
autoencoder_colorized = postprocess_image(grayscale, autoencoder_output)
|
71 |
|
72 |
+
if mode == "ResNet":
|
73 |
+
return resnet_colorized, None, None
|
74 |
+
elif mode == "MobileNet":
|
75 |
+
return None, mobilenet_colorized, None
|
76 |
+
elif mode == "Autoencoder":
|
77 |
+
return None, None, autoencoder_colorized
|
78 |
+
elif mode == "Comparison":
|
79 |
+
return resnet_colorized, mobilenet_colorized, autoencoder_colorized
|
80 |
+
|
81 |
|
82 |
# Gradio Interface
|
83 |
+
def gradio_interface():
|
84 |
+
with gr.Blocks() as demo:
|
85 |
+
# Input components
|
86 |
+
input_image = gr.Image(type="numpy", label="Upload an Image")
|
87 |
+
output_modes = gr.Radio(
|
88 |
+
choices=["ResNet", "MobileNet", "Autoencoder", "Comparison"],
|
89 |
+
value="ResNet",
|
90 |
+
label="Output Mode"
|
91 |
+
)
|
92 |
+
|
93 |
+
submit_button = gr.Button("Submit")
|
94 |
+
|
95 |
+
# Output components
|
96 |
+
with gr.Row(): # Place output images in a single row
|
97 |
+
resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
|
98 |
+
mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
|
99 |
+
autoencoder_output = gr.Image(label="Colorized Image (Autoencoder)", visible=False)
|
100 |
+
|
101 |
+
# Output mode logic
|
102 |
+
def update_visibility(mode):
|
103 |
+
if mode == "ResNet":
|
104 |
+
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
|
105 |
+
elif mode == "MobileNet":
|
106 |
+
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
|
107 |
+
elif mode == "Autoencoder":
|
108 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
|
109 |
+
elif mode == "Comparison":
|
110 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
111 |
+
|
112 |
+
# Dynamic event listener for output mode changes
|
113 |
+
output_modes.change(
|
114 |
+
fn=update_visibility,
|
115 |
+
inputs=[output_modes],
|
116 |
+
outputs=[resnet_output, mobilenet_output, autoencoder_output]
|
117 |
+
)
|
118 |
+
|
119 |
+
# Submit logic
|
120 |
+
|
121 |
+
submit_button.click(
|
122 |
+
fn=colorize_image,
|
123 |
+
inputs=[input_image, output_modes],
|
124 |
+
outputs=[resnet_output, mobilenet_output, autoencoder_output]
|
125 |
+
)
|
126 |
+
|
127 |
+
return demo
|
128 |
+
|
129 |
+
|
130 |
|
131 |
+
# Launch
|
132 |
+
if __name__ == "__main__":
|
133 |
+
gradio_interface().launch()
|
models.py
CHANGED
@@ -171,4 +171,124 @@ class MainModel(nn.Module):
|
|
171 |
self.set_requires_grad(self.net_D, False)
|
172 |
self.opt_G.zero_grad()
|
173 |
self.backward_G()
|
174 |
-
self.opt_G.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
self.set_requires_grad(self.net_D, False)
|
172 |
self.opt_G.zero_grad()
|
173 |
self.backward_G()
|
174 |
+
self.opt_G.step()
|
175 |
+
|
176 |
+
|
177 |
+
class UNetAuto(nn.Module):
|
178 |
+
|
179 |
+
def __init__(self, in_channels=1, out_channels=2, features=[64, 128, 256, 512]):
|
180 |
+
|
181 |
+
super(UNetAuto, self).__init__()
|
182 |
+
|
183 |
+
self.encoder = nn.ModuleList()
|
184 |
+
|
185 |
+
self.decoder = nn.ModuleList()
|
186 |
+
|
187 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
188 |
+
|
189 |
+
|
190 |
+
|
191 |
+
# Encoder part
|
192 |
+
|
193 |
+
for feature in features:
|
194 |
+
|
195 |
+
self.encoder.append(self._block(in_channels, feature))
|
196 |
+
|
197 |
+
in_channels = feature
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
# Decoder part (Upsampling)
|
202 |
+
|
203 |
+
for feature in reversed(features):
|
204 |
+
|
205 |
+
self.decoder.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
|
206 |
+
|
207 |
+
self.decoder.append(self._block(feature * 2, feature))
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
# Final Convolution
|
212 |
+
|
213 |
+
self.bottleneck = self._block(features[-1], features[-1] * 2)
|
214 |
+
|
215 |
+
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
def forward(self, x): #, t):
|
220 |
+
|
221 |
+
skip_connections = []
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
# Encode
|
226 |
+
|
227 |
+
for layer in self.encoder:
|
228 |
+
|
229 |
+
x = layer(x)
|
230 |
+
|
231 |
+
skip_connections.append(x)
|
232 |
+
|
233 |
+
x = self.pool(x)
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
# Bottleneck
|
238 |
+
|
239 |
+
x = self.bottleneck(x)
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
# Decode
|
244 |
+
|
245 |
+
skip_connections = skip_connections[::-1]
|
246 |
+
|
247 |
+
for idx in range(0, len(self.decoder), 2):
|
248 |
+
|
249 |
+
x = self.decoder[idx](x)
|
250 |
+
|
251 |
+
skip_connection = skip_connections[idx // 2]
|
252 |
+
|
253 |
+
x = torch.cat((x, skip_connection), dim=1) # Skip connection
|
254 |
+
|
255 |
+
x = self.decoder[idx + 1](x)
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
return self.final_conv(x)
|
260 |
+
|
261 |
+
|
262 |
+
|
263 |
+
def _block(self, in_channels, out_channels):
|
264 |
+
|
265 |
+
return nn.Sequential(
|
266 |
+
|
267 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
|
268 |
+
|
269 |
+
nn.BatchNorm2d(out_channels),
|
270 |
+
|
271 |
+
nn.ReLU(inplace=True),
|
272 |
+
|
273 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
|
274 |
+
|
275 |
+
nn.BatchNorm2d(out_channels),
|
276 |
+
|
277 |
+
nn.ReLU(inplace=True),
|
278 |
+
|
279 |
+
)
|
280 |
+
|
281 |
+
|
282 |
+
class Autoencoder(nn.Module):
|
283 |
+
|
284 |
+
def __init__(self, model):
|
285 |
+
|
286 |
+
super(Autoencoder, self).__init__()
|
287 |
+
|
288 |
+
self.model = model
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
def forward(self, x): #, t):
|
293 |
+
|
294 |
+
return self.model(x)#, t)
|
utils.py
CHANGED
@@ -28,12 +28,12 @@ def build_res_unet(n_input=1, n_output=2, size=256):
|
|
28 |
net_G = DynamicUnet(body, n_output, (size, size)).to(device)
|
29 |
return net_G
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
def create_loss_meters():
|
39 |
loss_D_fake = AverageMeter()
|
|
|
28 |
net_G = DynamicUnet(body, n_output, (size, size)).to(device)
|
29 |
return net_G
|
30 |
|
31 |
+
def build_mobilenet_unet(n_input=1, n_output=2, size=256):
|
32 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33 |
+
mobilenet = mobilenet_v2(pretrained=True)
|
34 |
+
body = create_body(mobilenet.features, pretrained=True, n_in=n_input, cut=-2)
|
35 |
+
net_G = DynamicUnet(body, n_output, (size, size)).to(device)
|
36 |
+
return net_G
|
37 |
|
38 |
def create_loss_meters():
|
39 |
loss_D_fake = AverageMeter()
|
weight/autoencoder.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a4231828be0fe2bb7f9701e809917661da56fd2f58a9f19728da0f936f4c2880
|
3 |
+
size 124234454
|