# Spaces: Build error | File size: 4,834 Bytes | 81170fd
import tensorflow as tf
import numpy as np
from PIL import Image
from typing import Sequence
from tqdm import tqdm
import argparse
import json
import os
import logging
# Module-level logger named after this module; used below to report an invalid image directory layout.
logger = logging.getLogger(__name__)
def images_to_tfrecords(image_dir, data_dir, has_labels):
    """
    Converts a folder of images to a TFRecord file.

    Writes 'dataset.tfrecords' (one tf.train.Example per image, holding the
    raw uint8 pixel bytes together with 'height', 'width', 'channels' and
    'label') and 'dataset_info.json' into 'data_dir'.

    The image directory should have one of the following structures:

    If has_labels = False, image_dir should look like this:

        path/to/image_dir/
            0.jpg
            1.jpg
            2.jpg
            4.jpg
            ...

    If has_labels = True, image_dir should look like this:

        path/to/image_dir/
            label0/
                0.jpg
                1.jpg
                ...
            label1/
                a.jpg
                b.jpg
                c.jpg
                ...
            ...

    Label directories are processed in sorted order, so the labels will be
    label0 -> 0, label1 -> 1. Only files with a png, jpg or jpeg extension
    (case-insensitive) are converted; everything else is skipped.

    Args:
        image_dir (str): Path to images.
        data_dir (str): Path where the TFrecords dataset is stored.
        has_labels (bool): If True, 'image_dir' contains label directories.

    Returns:
        (dict): Dataset info with the keys 'num_examples' and 'num_classes',
            or None if 'has_labels' is True but 'image_dir' contains an entry
            that is not a directory (nothing is written in that case).
    """
    supported_formats = ('png', 'jpg', 'jpeg')

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _list_images(directory):
        # Sorted so the record order is reproducible across file systems;
        # os.listdir makes no ordering guarantee.
        return sorted(
            name for name in os.listdir(directory)
            if os.path.splitext(name)[1][1:].lower() in supported_formats
        )

    def _write_example(writer, img_path, label):
        # The context manager releases the file handle once the pixels have
        # been copied into the array.
        with Image.open(img_path) as img_file:
            img = np.array(img_file, dtype=np.uint8)
        # Single-channel images decode to a 2-D array; give them an explicit
        # channel axis so the 'channels' feature can always be filled in.
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        height, width, channels = img.shape
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'channels': _int64_feature(channels),
            'image': _bytes_feature(img.tobytes()),
            'label': _int64_feature(label)}))
        writer.write(example.SerializeToString())

    os.makedirs(data_dir, exist_ok=True)
    tfrecords_path = os.path.join(data_dir, 'dataset.tfrecords')

    label_dirs = []
    if has_labels:
        # Sorted so the label -> index mapping is deterministic.
        label_dirs = sorted(os.listdir(image_dir))
        # Validate the layout before opening the writer, so an invalid
        # directory never leaves a partially written or unclosed file behind.
        if not all(os.path.isdir(os.path.join(image_dir, d)) for d in label_dirs):
            logger.warning('The image directory should contain one directory for each label.')
            logger.warning('These label directories should contain the image files.')
            # Same end state as before: no dataset file remains after a failed run.
            if os.path.exists(tfrecords_path):
                os.remove(tfrecords_path)
            return None

    num_examples = 0
    num_classes = len(label_dirs)
    with tf.io.TFRecordWriter(tfrecords_path) as writer:
        if has_labels:
            for label, label_dir in enumerate(label_dirs):
                class_dir = os.path.join(image_dir, label_dir)
                for img_name in tqdm(_list_images(class_dir)):
                    _write_example(writer, os.path.join(class_dir, img_name), label)
                    num_examples += 1
        else:
            for img_name in tqdm(_list_images(image_dir)):
                # Unlabeled data: every example carries the dummy label 0.
                _write_example(writer, os.path.join(image_dir, img_name), 0)
                num_examples += 1

    dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
    with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
        json.dump(dataset_info, fout)
    return dataset_info
if __name__ == '__main__':
    # Command-line entry point: convert an image folder into a TFRecords dataset.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without 'required=True' a missing flag would
    # reach the converter as None and fail with an opaque TypeError.
    parser.add_argument('--image_dir', type=str, required=True,
                        help='Path to the image directory.')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Path where the TFRecords dataset is stored.')
    parser.add_argument('--has_labels', action='store_true',
                        help='If True, image_dir contains label directories.')
    args = parser.parse_args()
    images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)
|