ChiKyi commited on
Commit
01c3f1c
1 Parent(s): 11f1268
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -11,7 +11,7 @@ from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet # Utility to
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Hàm load models
14
- def load_autoencoder_model(auto_model_path):
15
  unet = UNetAuto(in_channels=1, out_channels=2).to(device)
16
  model = Autoencoder(unet).to(device)
17
  model.load_state_dict(torch.load(auto_model_path, map_location=device))
@@ -44,7 +44,7 @@ mobilenet_model = load_model(
44
  model_type='mobilenet'
45
  )
46
 
47
- autoencoder_model = load_autoencoder_model("weight/autoencoder.pt")
48
 
49
  # Transformations
50
  def preprocess_image(image):
@@ -67,21 +67,21 @@ def colorize_image(input_image, mode):
67
  with torch.no_grad():
68
  resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
69
  mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
70
- autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
71
 
72
  # Resize outputs to match the original size
73
  resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
74
  mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
75
- autoencoder_colorized = postprocess_image(grayscale, autoencoder_output, original_size)
76
 
77
  if mode == "ResNet":
78
  return resnet_colorized, None, None
79
  elif mode == "MobileNet":
80
  return None, mobilenet_colorized, None
81
- elif mode == "Autoencoder":
82
- return None, None, autoencoder_colorized
83
  elif mode == "Comparison":
84
- return resnet_colorized, mobilenet_colorized, autoencoder_colorized
85
 
86
 
87
  # Gradio Interface
@@ -90,7 +90,7 @@ def gradio_interface():
90
  # Input components
91
  input_image = gr.Image(type="numpy", label="Upload an Image")
92
  output_modes = gr.Radio(
93
- choices=["ResNet", "MobileNet", "Autoencoder", "Comparison"],
94
  value="ResNet",
95
  label="Output Mode"
96
  )
@@ -101,7 +101,7 @@ def gradio_interface():
101
  with gr.Row(): # Place output images in a single row
102
  resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
103
  mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
104
- autoencoder_output = gr.Image(label="Colorized Image (Autoencoder)", visible=False)
105
 
106
  # Output mode logic
107
  def update_visibility(mode):
@@ -109,7 +109,7 @@ def gradio_interface():
109
  return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
110
  elif mode == "MobileNet":
111
  return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
112
- elif mode == "Autoencoder":
113
  return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
114
  elif mode == "Comparison":
115
  return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
@@ -118,14 +118,14 @@ def gradio_interface():
118
  output_modes.change(
119
  fn=update_visibility,
120
  inputs=[output_modes],
121
- outputs=[resnet_output, mobilenet_output, autoencoder_output]
122
  )
123
 
124
  # Submit logic
125
  submit_button.click(
126
  fn=colorize_image,
127
  inputs=[input_image, output_modes],
128
- outputs=[resnet_output, mobilenet_output, autoencoder_output]
129
  )
130
 
131
  return demo
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Hàm load models
14
+ def load_unet_model(auto_model_path):
15
  unet = UNetAuto(in_channels=1, out_channels=2).to(device)
16
  model = Autoencoder(unet).to(device)
17
  model.load_state_dict(torch.load(auto_model_path, map_location=device))
 
44
  model_type='mobilenet'
45
  )
46
 
47
+ unet_model = load_unet_model("weight/autoencoder.pt")
48
 
49
  # Transformations
50
  def preprocess_image(image):
 
67
  with torch.no_grad():
68
  resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
69
  mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
70
+ unet_output = unet_model(grayscale.unsqueeze(0))
71
 
72
  # Resize outputs to match the original size
73
  resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
74
  mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
75
+ unet_colorized = postprocess_image(grayscale, unet_output, original_size)
76
 
77
  if mode == "ResNet":
78
  return resnet_colorized, None, None
79
  elif mode == "MobileNet":
80
  return None, mobilenet_colorized, None
81
+ elif mode == "Unet":
82
+ return None, None, unet_colorized
83
  elif mode == "Comparison":
84
+ return resnet_colorized, mobilenet_colorized, unet_colorized
85
 
86
 
87
  # Gradio Interface
 
90
  # Input components
91
  input_image = gr.Image(type="numpy", label="Upload an Image")
92
  output_modes = gr.Radio(
93
+ choices=["ResNet", "MobileNet", "Unet", "Comparison"],
94
  value="ResNet",
95
  label="Output Mode"
96
  )
 
101
  with gr.Row(): # Place output images in a single row
102
  resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
103
  mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
104
+ unet_output = gr.Image(label="Colorized Image (Unet)", visible=False)
105
 
106
  # Output mode logic
107
  def update_visibility(mode):
 
109
  return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
110
  elif mode == "MobileNet":
111
  return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
112
+ elif mode == "Unet":
113
  return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
114
  elif mode == "Comparison":
115
  return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
 
118
  output_modes.change(
119
  fn=update_visibility,
120
  inputs=[output_modes],
121
+ outputs=[resnet_output, mobilenet_output, unet_output]
122
  )
123
 
124
  # Submit logic
125
  submit_button.click(
126
  fn=colorize_image,
127
  inputs=[input_image, output_modes],
128
+ outputs=[resnet_output, mobilenet_output, unet_output]
129
  )
130
 
131
  return demo