import tensorflow as tf
import numpy as np
from PIL import Image


def load_img(path_to_img, max_dim=512):
    """Read an image file and return it as a resized, batched float tensor.

    The file contents are read with ``tf.io.read_file`` and handed to
    :func:`transform_img`, so both entry points share one decode/resize
    pipeline.

    Args:
        path_to_img: Path of the encoded image file to read.
        max_dim: Target length, in pixels, of the image's longest side.

    Returns:
        A ``tf.float32`` tensor with a leading batch axis of size 1.
    """
    return transform_img(tf.io.read_file(path_to_img), max_dim)


def transform_img(img, max_dim=512):
    """Decode encoded image bytes and return a resized, batched float tensor.

    The image is scaled so that its longest spatial side equals ``max_dim``
    while the aspect ratio is preserved (the shorter side is truncated to an
    integer pixel count).

    Args:
        img: Encoded image contents, as accepted by ``tf.image.decode_image``.
        max_dim: Target length, in pixels, of the image's longest side.

    Returns:
        A ``tf.float32`` tensor with a leading batch axis of size 1.
    """
    # expand_animations=False keeps the result 3-D even for GIFs (first frame
    # only); a 4-D animated result would yield a 3-element size below and make
    # tf.image.resize fail.
    img = tf.image.decode_image(img, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)

    # Every axis except the last (channels) is a spatial dimension.
    spatial_shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    # tf.reduce_max instead of the Python builtin max(): the builtin iterates
    # the tensor element by element and only works in eager execution.
    scale = max_dim / tf.reduce_max(spatial_shape)
    new_shape = tf.cast(spatial_shape * scale, tf.int32)

    img = tf.image.resize(img, new_shape)
    # Prepend the batch axis.
    return img[tf.newaxis, :]


def imshow(image, title=None):
    """Strip the batch axis and singleton dimensions from an image.

    Despite its name, this function does not display anything; it only
    returns the squeezed image data.

    Args:
        image: Image data exposing a ``shape`` attribute. If it has more
            than three dimensions, axis 0 is squeezed out first.
        title: Unused; kept for call-site compatibility.

    Returns:
        The result of ``np.squeeze`` applied to the image.
    """
    if len(image.shape) > 3:
        image = tf.squeeze(image, axis=0)
    return np.squeeze(image)


def tensor_to_image(tensor):
    """Convert image data scaled to [0, 1] into a ``PIL.Image.Image``.

    Args:
        tensor: Image data with values nominally in [0, 1]. If it has more
            than three dimensions, the first one must have size 1 and is
            dropped.

    Returns:
        A ``PIL.Image.Image`` built from the 8-bit pixel array.

    Raises:
        ValueError: If the data has more than three dimensions and the
            leading dimension is not of size 1.
    """
    pixels = np.array(tensor * 255)
    # Clip before the cast: values marginally outside [0, 1] would otherwise
    # wrap around when converted to uint8 (e.g. 256 -> 0).
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    if np.ndim(pixels) > 3:
        # Explicit exception rather than assert, which is stripped under -O.
        if pixels.shape[0] != 1:
            raise ValueError(
                "expected a leading batch dimension of size 1, "
                f"got shape {pixels.shape}"
            )
        pixels = pixels[0]
    return Image.fromarray(pixels)