|
import os |
|
from matplotlib import gridspec |
|
import matplotlib.pylab as plt |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
def crop_center(image):
    """Crop an image to the largest centered square.

    Axes 1 and 2 of ``image.shape`` are treated as height and width
    respectively (i.e. a leading batch axis is expected).

    Args:
        image: Array or tensor accepted by
            ``tf.image.crop_to_bounding_box``.

    Returns:
        The cropped image, whose height and width both equal the shorter
        of the two input sides.
    """
    dims = image.shape
    height, width = dims[1], dims[2]
    side = min(height, width)

    # Only the longer axis gets a non-zero offset; splitting the surplus
    # evenly keeps the crop window centered.
    top = (height - side) // 2
    left = (width - side) // 2

    return tf.image.crop_to_bounding_box(image, top, left, side, side)
|
|
|
def load_image(image_url, image_size=(256, 256), preserve_aspect_ratio=True):
    """Download an image and preprocess it into a square, batched tensor.

    Processing steps, in order: fetch the file with
    ``tf.keras.utils.get_file``, read it with ``plt.imread``, cast to
    float32, add a leading batch axis, scale to [0, 1] if the data looks
    like 0-255 values, expand grayscale input to three channels,
    center-crop to a square and finally resize.

    Args:
        image_url: URL of the image to download.
        image_size: Target ``(height, width)`` passed to
            ``tf.image.resize``.
        preserve_aspect_ratio: Forwarded to ``tf.image.resize``. Because
            the image is already square after the center crop, this only
            matters when ``image_size`` itself is not square.

    Returns:
        The resized image as returned by ``tf.image.resize``, with a
        leading batch axis of size 1.
    """
    # Use (at most) the last 128 characters of the URL's basename as the
    # local file name handed to get_file.
    cache_name = os.path.basename(image_url)[-128:]
    image_path = tf.keras.utils.get_file(cache_name, image_url)

    # Read the file, convert to float32 and prepend a batch axis.
    img = plt.imread(image_path).astype(np.float32)[np.newaxis, ...]

    # Heuristic: values above 1.0 are taken to mean a 0-255 range.
    if img.max() > 1.0:
        img = img / 255.

    # Rank 3 after adding the batch axis means there is no channel axis
    # (grayscale); replicate the single plane into three channels.
    if len(img.shape) == 3:
        img = tf.stack([img, img, img], axis=-1)

    img = crop_center(img)

    # Honor the caller's choice instead of always forcing True.
    img = tf.image.resize(
        img, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
    return img