# ImageColorization / utils.py
# Author: sebastiansarasti
# Commit eb29cc1 ("Create utils.py", verified) — 1.14 kB
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import Compose, Resize, Grayscale, ToTensor, ToPILImage
# Module-level grayscale preprocessing pipeline shared by every call to
# process_gs_image: resize to 128x128, collapse to one grayscale channel,
# then convert the PIL image to a tensor.
transform_gs = Compose(
    [
        Resize((128, 128)),
        Grayscale(num_output_channels=1),
        ToTensor(),
    ]
)
def process_gs_image(image):
    """
    Prepare an input image for the colorization model.

    The image is run through the module-level ``transform_gs`` pipeline
    (resize to 128x128, single-channel grayscale, tensor conversion) and a
    leading batch dimension of size 1 is prepended.

    Args:
        image: Input image; its ``.size`` is read as ``(width, height)``,
            i.e. a PIL image is expected.

    Returns:
        Tuple ``(batch, original_size)``: the transformed tensor with a batch
        dimension added, and the input's ``(width, height)`` so the output
        can later be resized back.
    """
    # Capture the size before transforming; the pipeline resizes the image.
    width_height = image.size
    batch = transform_gs(image).unsqueeze(0)
    return batch, width_height
def inverse_transform_cs(tensor, original_size):
    """
    Convert a (model output) tensor back to a color PIL image at its original size.

    Args:
        tensor: Image tensor, with an optional leading batch dimension of
            size 1 that is squeezed away. Floating-point tensors are expected
            to hold values in [0, 1]; values outside that range are clamped.
        original_size: Target ``(width, height)`` tuple, e.g. the second value
            returned by ``process_gs_image``.

    Returns:
        PIL image resized to ``original_size``.
    """
    # Detach from autograd and move to CPU so the conversion below operates
    # on a plain CPU tensor regardless of where the model ran.
    tensor = tensor.detach().cpu()
    # Remove the batch dimension (no-op if dim 0 is not of size 1).
    tensor = tensor.squeeze(0)
    # ToPILImage scales float tensors by 255 and casts to uint8 without
    # clamping; out-of-range values would overflow the cast and produce
    # corrupted colors, so clamp float inputs to the valid [0, 1] range.
    if tensor.is_floating_point():
        tensor = tensor.clamp(0.0, 1.0)
    pil_image = ToPILImage()(tensor)
    # PIL's resize takes (width, height), matching PIL's Image.size order.
    return pil_image.resize(original_size)