ChiKyi commited on
Commit
11f1268
1 Parent(s): 68fafaa

update output size

Browse files
Files changed (2) hide show
  1. app.py +11 -8
  2. weight/denoising diffusion.pt +3 -0
app.py CHANGED
@@ -52,12 +52,16 @@ def preprocess_image(image):
52
  image = transforms.ToTensor()(image)[:1] * 2. - 1.
53
  return image
54
 
55
- def postprocess_image(grayscale, prediction):
56
- return lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
 
 
 
57
 
58
  # Prediction function with output control
59
  def colorize_image(input_image, mode):
60
  grayscale_image = Image.fromarray(input_image).convert('L')
 
61
  grayscale = preprocess_image(grayscale_image).to(device)
62
 
63
  with torch.no_grad():
@@ -65,9 +69,10 @@ def colorize_image(input_image, mode):
65
  mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
66
  autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
67
 
68
- resnet_colorized = postprocess_image(grayscale, resnet_output)
69
- mobilenet_colorized = postprocess_image(grayscale, mobilenet_output)
70
- autoencoder_colorized = postprocess_image(grayscale, autoencoder_output)
 
71
 
72
  if mode == "ResNet":
73
  return resnet_colorized, None, None
@@ -117,7 +122,6 @@ def gradio_interface():
117
  )
118
 
119
  # Submit logic
120
-
121
  submit_button.click(
122
  fn=colorize_image,
123
  inputs=[input_image, output_modes],
@@ -127,7 +131,6 @@ def gradio_interface():
127
  return demo
128
 
129
 
130
-
131
  # Launch
132
  if __name__ == "__main__":
133
- gradio_interface().launch()
 
52
  image = transforms.ToTensor()(image)[:1] * 2. - 1.
53
  return image
54
 
55
+ def postprocess_image(grayscale, prediction, original_size):
56
+ # Convert Lab back to RGB and resize to the original image size
57
+ colorized_image = lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
58
+ colorized_image = Image.fromarray((colorized_image * 255).astype("uint8"))
59
+ return colorized_image.resize(original_size)
60
 
61
  # Prediction function with output control
62
  def colorize_image(input_image, mode):
63
  grayscale_image = Image.fromarray(input_image).convert('L')
64
+ original_size = grayscale_image.size # Store original size
65
  grayscale = preprocess_image(grayscale_image).to(device)
66
 
67
  with torch.no_grad():
 
69
  mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
70
  autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
71
 
72
+ # Resize outputs to match the original size
73
+ resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
74
+ mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
75
+ autoencoder_colorized = postprocess_image(grayscale, autoencoder_output, original_size)
76
 
77
  if mode == "ResNet":
78
  return resnet_colorized, None, None
 
122
  )
123
 
124
  # Submit logic
 
125
  submit_button.click(
126
  fn=colorize_image,
127
  inputs=[input_image, output_modes],
 
131
  return demo
132
 
133
 
 
134
  # Launch
135
  if __name__ == "__main__":
136
+ gradio_interface().launch()
weight/denoising diffusion.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36234b7f753bfaa3c957c2f9cd67e0f4408eb29e479afb984c4df5fd6459d147
3
+ size 124234454