ChiKyi committed
Commit
68fafaa
1 Parent(s): 4d92358

update models

Adds a MobileNet U-Net generator and a plain U-Net autoencoder alongside the existing ResNet18 U-Net, and reworks the Gradio app so the output mode (ResNet / MobileNet / Autoencoder / Comparison) is selectable at inference time.

Files changed (4)
  1. app.py +93 -50
  2. models.py +121 -1
  3. utils.py +6 -6
  4. weight/autoencoder.pt +3 -0
app.py CHANGED
@@ -1,90 +1,133 @@
 import torch
+import numpy as np
 from PIL import Image
 from torchvision import transforms
 from matplotlib import pyplot as plt
 import gradio as gr
 
-from models import MainModel  # Import class for your main model
-from utils import lab_to_rgb, build_res_unet  #, build_mobile_unet  # Utility to convert LAB to RGB
+from models import MainModel, UNetAuto, Autoencoder
+from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet  # Utility to convert LAB to RGB
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-
-def load_model(generator_model_path, colorization_model_path):  #, model_type='resnet')
-
-    #if model_type == 'resnet':
-    net_G = build_res_unet(n_input=1, n_output=2, size=256)
-    # elif model_type == 'mobilenet':
-    #     net_G = build_mobile_unet(n_input=1, n_output=2, size=256)
+# Load models
+def load_autoencoder_model(auto_model_path):
+    unet = UNetAuto(in_channels=1, out_channels=2).to(device)
+    model = Autoencoder(unet).to(device)
+    model.load_state_dict(torch.load(auto_model_path, map_location=device))
+    model.to(device)
+    model.eval()
+    return model
+
+def load_model(generator_model_path, colorization_model_path, model_type='resnet'):
+    if model_type == 'resnet':
+        net_G = build_res_unet(n_input=1, n_output=2, size=256)
+    elif model_type == 'mobilenet':
+        net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256)
 
     net_G.load_state_dict(torch.load(generator_model_path, map_location=device))
-
-    # Create MainModel and load weights
     model = MainModel(net_G=net_G)
     model.load_state_dict(torch.load(colorization_model_path, map_location=device))
-
-    # Move model to device and set to eval mode
     model.to(device)
     model.eval()
-
     return model
 
-# Load pretrained models
 resnet_model = load_model(
     "weight/pascal_res18-unet.pt",
-    "weight/pascal_final_model_weights.pt"
-    # model_type='resnet'
+    "weight/pascal_final_model_weights.pt",
+    model_type='resnet'
 )
 
-# mobilenet_model = load_model(
-#     "weight/mobile-unet.pt",
-#     "weight/mobile_pascal_final_model_weights.pt",
-#     model_type='mobilenet'
-# )
+mobilenet_model = load_model(
+    "weight/mobile-unet.pt",
+    "weight/mobile_pascal_final_model_weights.pt",
+    model_type='mobilenet'
+)
+
+autoencoder_model = load_autoencoder_model("weight/autoencoder.pt")
 
 # Transformations
 def preprocess_image(image):
     image = image.resize((256, 256))
     image = transforms.ToTensor()(image)[:1] * 2. - 1.  # Normalize to [-1, 1]
     return image
 
 def postprocess_image(grayscale, prediction):
     return lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
 
-# Prediction function
-def colorize_image(input_image):
-    # Convert input to grayscale
-    input_image = Image.fromarray(input_image).convert('L')
-    grayscale = preprocess_image(input_image).to(device)
+# Prediction function with output control
+def colorize_image(input_image, mode):
+    grayscale_image = Image.fromarray(input_image).convert('L')
+    grayscale = preprocess_image(grayscale_image).to(device)
 
-    # Generate predictions
     with torch.no_grad():
         resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
-        # mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
+        mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
+        autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
 
-    # Post-process results
     resnet_colorized = postprocess_image(grayscale, resnet_output)
-    # mobilenet_colorized = postprocess_image(grayscale, mobilenet_output)
+    mobilenet_colorized = postprocess_image(grayscale, mobilenet_output)
+    autoencoder_colorized = postprocess_image(grayscale, autoencoder_output)
 
-    return (
-        input_image,       # Grayscale image
-        resnet_colorized   # ResNet18 colorized image
-        # mobilenet_colorized  # MobileNet colorized image
-    )
+    if mode == "ResNet":
+        return resnet_colorized, None, None
+    elif mode == "MobileNet":
+        return None, mobilenet_colorized, None
+    elif mode == "Autoencoder":
+        return None, None, autoencoder_colorized
+    elif mode == "Comparison":
+        return resnet_colorized, mobilenet_colorized, autoencoder_colorized
 
 # Gradio Interface
-interface = gr.Interface(
-    fn=colorize_image,
-    inputs=gr.Image(type="numpy", label="Upload a Color Image"),
-    outputs=[
-        gr.Image(label="Grayscale Image"),
-        gr.Image(label="Colorized Image (ResNet18)")
-        # gr.Image(label="Colorized Image (MobileNet)")
-    ],
-    title="Image Colorization",
-    description="Upload a color image"
-)
+def gradio_interface():
+    with gr.Blocks() as demo:
+        # Input components
+        input_image = gr.Image(type="numpy", label="Upload an Image")
+        output_modes = gr.Radio(
+            choices=["ResNet", "MobileNet", "Autoencoder", "Comparison"],
+            value="ResNet",
+            label="Output Mode"
+        )
+
+        submit_button = gr.Button("Submit")
+
+        # Output components
+        with gr.Row():  # Place output images in a single row
+            resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
+            mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
+            autoencoder_output = gr.Image(label="Colorized Image (Autoencoder)", visible=False)
+
+        # Output mode logic
+        def update_visibility(mode):
+            if mode == "ResNet":
+                return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
+            elif mode == "MobileNet":
+                return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
+            elif mode == "Autoencoder":
+                return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+            elif mode == "Comparison":
+                return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+
+        # Dynamic event listener for output mode changes
+        output_modes.change(
+            fn=update_visibility,
+            inputs=[output_modes],
+            outputs=[resnet_output, mobilenet_output, autoencoder_output]
+        )
+
+        # Submit logic
+        submit_button.click(
+            fn=colorize_image,
+            inputs=[input_image, output_modes],
+            outputs=[resnet_output, mobilenet_output, autoencoder_output]
+        )
+
+    return demo
 
-# Launch Gradio app
-if __name__ == '__main__':
-    interface.launch()
+# Launch
+if __name__ == "__main__":
+    gradio_interface().launch()
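
For context, a minimal headless sketch of exercising the new `colorize_image` entry point without the Gradio UI. The input file name is a placeholder, and the `[0, 1]` float range assumed for `lab_to_rgb`'s output follows the usual skimage `lab2rgb` convention rather than anything this diff pins down. Note also that all three output images start with `visible=False`, so in the UI they only appear once the mode radio fires its first `change` event.

```python
# Headless sketch (assumes the weight/ files from this repo are present;
# importing app runs the model loading at module scope, while the UI
# launch stays behind the __main__ guard).
import numpy as np
from PIL import Image
from app import colorize_image

img = np.array(Image.open("example.jpg"))  # placeholder input path

# "Comparison" returns all three colorizations; the single-model modes
# return None for the other two slots.
resnet_rgb, mobilenet_rgb, autoencoder_rgb = colorize_image(img, mode="Comparison")

for name, rgb in [("resnet", resnet_rgb), ("mobilenet", mobilenet_rgb),
                  ("autoencoder", autoencoder_rgb)]:
    # Assumes lab_to_rgb yields float arrays in [0, 1], as skimage's lab2rgb does.
    Image.fromarray((rgb * 255).astype(np.uint8)).save(f"{name}_colorized.png")
```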
models.py CHANGED
@@ -171,4 +171,124 @@ class MainModel(nn.Module):
         self.set_requires_grad(self.net_D, False)
         self.opt_G.zero_grad()
         self.backward_G()
-        self.opt_G.step()
+        self.opt_G.step()
+
+
+class UNetAuto(nn.Module):
+    def __init__(self, in_channels=1, out_channels=2, features=[64, 128, 256, 512]):
+        super(UNetAuto, self).__init__()
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        # Encoder part
+        for feature in features:
+            self.encoder.append(self._block(in_channels, feature))
+            in_channels = feature
+
+        # Decoder part (upsampling)
+        for feature in reversed(features):
+            self.decoder.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
+            self.decoder.append(self._block(feature * 2, feature))
+
+        # Bottleneck and final convolution
+        self.bottleneck = self._block(features[-1], features[-1] * 2)
+        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
+
+    def forward(self, x):
+        skip_connections = []
+
+        # Encode
+        for layer in self.encoder:
+            x = layer(x)
+            skip_connections.append(x)
+            x = self.pool(x)
+
+        # Bottleneck
+        x = self.bottleneck(x)
+
+        # Decode
+        skip_connections = skip_connections[::-1]
+        for idx in range(0, len(self.decoder), 2):
+            x = self.decoder[idx](x)
+            skip_connection = skip_connections[idx // 2]
+            x = torch.cat((x, skip_connection), dim=1)  # Skip connection
+            x = self.decoder[idx + 1](x)
+
+        return self.final_conv(x)
+
+    def _block(self, in_channels, out_channels):
+        return nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+        )
+
+
+class Autoencoder(nn.Module):
+    def __init__(self, model):
+        super(Autoencoder, self).__init__()
+        self.model = model
+
+    def forward(self, x):
+        return self.model(x)
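
A quick shape check for the new classes (a sketch, not part of the commit): the four encoder stages pool a 256×256 L-channel input down to 16×16 before the bottleneck, and the transposed convolutions plus skip connections bring it back up, so the output should be the two ab channels at full resolution. Side lengths not divisible by 16 would break the `torch.cat` against the stored skips.

```python
# Shape sanity check for UNetAuto wrapped in Autoencoder (sketch).
import torch
from models import UNetAuto, Autoencoder

net = Autoencoder(UNetAuto(in_channels=1, out_channels=2)).eval()
x = torch.randn(1, 1, 256, 256)  # one grayscale (L-channel) image
with torch.no_grad():
    ab = net(x)
print(ab.shape)  # expected: torch.Size([1, 2, 256, 256])
```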
utils.py CHANGED
@@ -28,12 +28,12 @@ def build_res_unet(n_input=1, n_output=2, size=256):
     net_G = DynamicUnet(body, n_output, (size, size)).to(device)
     return net_G
 
-# def build_mobile_unet(n_input=1, n_output=2, size=256):
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#     mobilenet_model = mobilenet_v2(pretrained=True)
-#     body = create_body(mobilenet_model, n_in=n_input, cut=-2)
-#     net_G = DynamicUnet(body, n_output, (size, size)).to(device)
-#     return net_G
+def build_mobilenet_unet(n_input=1, n_output=2, size=256):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    mobilenet = mobilenet_v2(pretrained=True)
+    body = create_body(mobilenet.features, pretrained=True, n_in=n_input, cut=-2)
+    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
+    return net_G
 
 def create_loss_meters():
     loss_D_fake = AverageMeter()
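
The rewritten builder hands `create_body` the `mobilenet.features` `nn.Sequential` rather than the full classifier, so `cut=-2` drops the last two feature blocks and `n_in=1` remaps the stem to a single input channel; `DynamicUnet` then infers the decoder from that body. A hedged smoke test, assuming the same fastai imports that `build_res_unet` already relies on:

```python
# Smoke test for the new MobileNet generator (sketch, not in the commit).
import torch
from utils import build_mobilenet_unet

net_G = build_mobilenet_unet(n_input=1, n_output=2, size=256).eval()
device = next(net_G.parameters()).device  # the builder moves the net to CUDA when available
with torch.no_grad():
    ab = net_G(torch.randn(1, 1, 256, 256, device=device))
print(ab.shape)  # expected: torch.Size([1, 2, 256, 256])
```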
weight/autoencoder.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a4231828be0fe2bb7f9701e809917661da56fd2f58a9f19728da0f936f4c2880
+size 124234454
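
The three added lines are a Git LFS pointer, not the tensor data itself: they record the LFS spec version, the SHA-256 of the real blob, and its size (124234454 bytes, about 124 MB). A plain clone without LFS leaves only this stub in place, and `torch.load("weight/autoencoder.pt")` will fail on it; running `git lfs install` once and then `git lfs pull --include="weight/autoencoder.pt"` fetches the actual weights.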