Spaces:
Sleeping
Sleeping
update output size
Browse files- app.py +11 -8
- weight/denoising diffusion.pt +3 -0
app.py
CHANGED
@@ -52,12 +52,16 @@ def preprocess_image(image):
|
|
52 |
image = transforms.ToTensor()(image)[:1] * 2. - 1.
|
53 |
return image
|
54 |
|
55 |
-
def postprocess_image(grayscale, prediction):
|
56 |
-
|
|
|
|
|
|
|
57 |
|
58 |
# Prediction function with output control
|
59 |
def colorize_image(input_image, mode):
|
60 |
grayscale_image = Image.fromarray(input_image).convert('L')
|
|
|
61 |
grayscale = preprocess_image(grayscale_image).to(device)
|
62 |
|
63 |
with torch.no_grad():
|
@@ -65,9 +69,10 @@ def colorize_image(input_image, mode):
|
|
65 |
mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
|
66 |
autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
71 |
|
72 |
if mode == "ResNet":
|
73 |
return resnet_colorized, None, None
|
@@ -117,7 +122,6 @@ def gradio_interface():
|
|
117 |
)
|
118 |
|
119 |
# Submit logic
|
120 |
-
|
121 |
submit_button.click(
|
122 |
fn=colorize_image,
|
123 |
inputs=[input_image, output_modes],
|
@@ -127,7 +131,6 @@ def gradio_interface():
|
|
127 |
return demo
|
128 |
|
129 |
|
130 |
-
|
131 |
# Launch
|
132 |
if __name__ == "__main__":
|
133 |
-
gradio_interface().launch()
|
|
|
52 |
image = transforms.ToTensor()(image)[:1] * 2. - 1.
|
53 |
return image
|
54 |
|
55 |
+
def postprocess_image(grayscale, prediction, original_size):
|
56 |
+
# Convert Lab back to RGB and resize to the original image size
|
57 |
+
colorized_image = lab_to_rgb(grayscale.unsqueeze(0), prediction.cpu())[0]
|
58 |
+
colorized_image = Image.fromarray((colorized_image * 255).astype("uint8"))
|
59 |
+
return colorized_image.resize(original_size)
|
60 |
|
61 |
# Prediction function with output control
|
62 |
def colorize_image(input_image, mode):
|
63 |
grayscale_image = Image.fromarray(input_image).convert('L')
|
64 |
+
original_size = grayscale_image.size # Store original size
|
65 |
grayscale = preprocess_image(grayscale_image).to(device)
|
66 |
|
67 |
with torch.no_grad():
|
|
|
69 |
mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
|
70 |
autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))
|
71 |
|
72 |
+
# Resize outputs to match the original size
|
73 |
+
resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
|
74 |
+
mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
|
75 |
+
autoencoder_colorized = postprocess_image(grayscale, autoencoder_output, original_size)
|
76 |
|
77 |
if mode == "ResNet":
|
78 |
return resnet_colorized, None, None
|
|
|
122 |
)
|
123 |
|
124 |
# Submit logic
|
|
|
125 |
submit_button.click(
|
126 |
fn=colorize_image,
|
127 |
inputs=[input_image, output_modes],
|
|
|
131 |
return demo
|
132 |
|
133 |
|
|
|
134 |
# Launch
|
135 |
if __name__ == "__main__":
|
136 |
+
gradio_interface().launch()
|
weight/denoising diffusion.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:36234b7f753bfaa3c957c2f9cd67e0f4408eb29e479afb984c4df5fd6459d147
|
3 |
+
size 124234454
|