# Spaces:
# Runtime error
# Runtime error
import tensorflow as tf | |
from typing import List | |
from ..commons import read_image | |
from ..augmentation import UnpairedAugmentationFactory | |
class UnpairedLowLightDataset:
    """Build batched ``tf.data`` pipelines from a list of unpaired images.

    Every element is passed through ``read_image`` and then either through
    ``UnpairedAugmentationFactory.random_crop`` or, when ``apply_resize`` is
    set, through ``tf.image.resize`` to ``(image_size, image_size)``.
    Training pipelines additionally apply whichever of the factory's random
    flips / rotation are enabled.

    Args:
        image_size: Passed to ``UnpairedAugmentationFactory`` and used as the
            square target size when resizing.
        apply_resize: Resize images instead of random-cropping them.
        apply_random_horizontal_flip: Apply the factory's random horizontal
            flip to training data.
        apply_random_vertical_flip: Apply the factory's random vertical flip
            to training data.
        apply_random_rotation: Apply the factory's random rotation to
            training data.
    """

    def __init__(
        self,
        image_size: int = 256,
        apply_resize: bool = False,
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
    ) -> None:
        self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
        self.image_size = image_size
        self.apply_resize = apply_resize
        self.apply_random_horizontal_flip = apply_random_horizontal_flip
        self.apply_random_vertical_flip = apply_random_vertical_flip
        self.apply_random_rotation = apply_random_rotation

    def _resize(self, image):
        """Resize ``image`` to a square of side ``self.image_size``."""
        return tf.image.resize(image, (self.image_size, self.image_size))

    @staticmethod
    def _map(dataset, transform):
        """Apply ``transform`` to every element with autotuned parallelism."""
        return dataset.map(transform, num_parallel_calls=tf.data.AUTOTUNE)

    def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
        """Create one batched dataset from ``images``.

        Args:
            images: Entries handed one by one to ``read_image``
                (presumably file paths; confirm against ``read_image``).
            batch_size: Number of elements per batch.
            is_train: When true, the enabled random augmentations are applied.

        Returns:
            A ``tf.data.Dataset`` of full batches only; a trailing partial
            batch is dropped (``drop_remainder=True``).
        """
        # Pin the dtype to string: an empty Python list would otherwise be
        # inferred as float32, which does not match the ``List[str]`` input
        # this pipeline is declared to take.
        entries = tf.convert_to_tensor(images, dtype=tf.string)
        dataset = tf.data.Dataset.from_tensor_slices(entries)
        dataset = self._map(dataset, read_image)

        if self.apply_resize:
            dataset = self._map(dataset, self._resize)
        else:
            dataset = self._map(dataset, self.augmentation_factory.random_crop)

        if is_train:
            factory = self.augmentation_factory
            if self.apply_random_horizontal_flip:
                dataset = self._map(dataset, factory.random_horizontal_flip)
            if self.apply_random_vertical_flip:
                dataset = self._map(dataset, factory.random_vertical_flip)
            if self.apply_random_rotation:
                dataset = self._map(dataset, factory.random_rotate)

        return dataset.batch(batch_size, drop_remainder=True)

    def get_datasets(
        self,
        images: List[str],
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Split ``images`` in order and build the train / validation datasets.

        The first ``int(len(images) * (1 - val_split))`` entries form the
        training set and the rest form the validation set; the list is not
        shuffled here. Because partial batches are dropped, a split holding
        fewer than ``batch_size`` entries produces no batches.

        Args:
            images: Entries to split (see ``_get_dataset``).
            val_split: Fraction of ``images`` reserved for validation; must
                lie in ``[0, 1]``.
            batch_size: Batch size used for both datasets.

        Returns:
            A ``(train_dataset, val_dataset)`` tuple.

        Raises:
            ValueError: If ``val_split`` is outside ``[0, 1]``.
        """
        # Outside [0, 1] the split index becomes negative or exceeds the list
        # length, and slicing would silently yield a meaningless split.
        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be between 0 and 1, got {val_split!r}")

        split_index = int(len(images) * (1 - val_split))
        train_images = images[:split_index]
        val_images = images[split_index:]
        print(f"Number of train data points: {len(train_images)}")
        print(f"Number of validation data points: {len(val_images)}")
        train_dataset = self._get_dataset(train_images, batch_size, is_train=True)
        val_dataset = self._get_dataset(val_images, batch_size, is_train=False)
        return train_dataset, val_dataset