tkarras committed on
Commit
1d25833
·
1 Parent(s): f0a4246

Add support for MNIST

Browse files
Files changed (1) hide show
  1. dataset_tool.py +44 -11
dataset_tool.py CHANGED
@@ -13,6 +13,7 @@ import os
13
  import pickle
14
  import sys
15
  import tarfile
 
16
  import zipfile
17
  from pathlib import Path
18
  from typing import Callable, Optional, Tuple, Union
@@ -165,6 +166,36 @@ def open_cifar10(tarball: str, *, max_images: Optional[int]):
165
 
166
  #----------------------------------------------------------------------------
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def make_transform(
169
  transform: Optional[str],
170
  output_width: Optional[int],
@@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]):
225
  else:
226
  return open_image_folder(source, max_images=max_images)
227
  elif os.path.isfile(source):
228
- if source.endswith('cifar-10-python.tar.gz'):
229
  return open_cifar10(source, max_images=max_images)
230
- ext = file_ext(source)
231
- if ext == 'zip':
 
232
  return open_image_zip(source, max_images=max_images)
233
  else:
234
  assert False, 'unknown archive type'
@@ -293,17 +325,18 @@ def convert_dataset(
293
  The input dataset format is guessed from the --source argument:
294
 
295
  \b
296
- --source *_lmdb/ - Load LSUN dataset
297
- --source cifar-10-python.tar.gz - Load CIFAR-10 dataset
298
- --source path/ - Recursively load all images from path/
299
- --source dataset.zip - Recursively load all images from dataset.zip
 
300
 
301
- The output dataset format can be either an image folder or a zip archive. Specifying
302
- the output format and path:
303
 
304
  \b
305
- --dest /path/to/dir - Save output files under /path/to/dir
306
- --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive
307
 
308
  Images within the dataset archive will be stored as uncompressed PNG.
309
 
 
13
  import pickle
14
  import sys
15
  import tarfile
16
+ import gzip
17
  import zipfile
18
  from pathlib import Path
19
  from typing import Callable, Optional, Tuple, Union
 
166
 
167
  #----------------------------------------------------------------------------
168
 
169
def open_mnist(images_gz: str, *, max_images: Optional[int]):
    """Open the MNIST training set stored as two gzip-compressed IDX files.

    The label archive is expected to sit next to the image archive; its path
    is derived by replacing the ``-images-idx3-ubyte.gz`` suffix of
    ``images_gz`` with ``-labels-idx1-ubyte.gz``.

    Args:
        images_gz: Path to the image archive (e.g. ``train-images-idx3-ubyte.gz``).
        max_images: Limit on the number of images, forwarded to ``maybe_min``
            together with the dataset size (presumably ``None`` means
            "no limit" -- confirm against ``maybe_min``).

    Returns:
        A ``(count, iterator)`` tuple. ``count`` is the value returned by
        ``maybe_min``; the iterator yields ``dict(img=..., label=...)`` items
        where ``img`` is a 32x32 ``uint8`` array and ``label`` is an ``int``.

    Raises:
        AssertionError: If the path does not contain the expected image-file
            suffix, or if the decoded data does not look like the 60000-image
            MNIST training set.
    """
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    assert labels_gz != images_gz, 'images_gz must end with "-images-idx3-ubyte.gz"'

    # Skip the fixed-size headers (16 bytes for images, 8 bytes for labels)
    # and view the remaining payload as raw bytes.
    with gzip.open(images_gz, 'rb') as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(labels_gz, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    # Zero-pad the 28x28 digits by 2 pixels on every side to get 32x32 images.
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8, 'unexpected MNIST image data'
    assert labels.shape == (60000,) and labels.dtype == np.uint8, 'unexpected MNIST label data'
    assert np.min(images) == 0 and np.max(images) == 255, 'unexpected MNIST pixel range'
    assert np.min(labels) == 0 and np.max(labels) == 9, 'unexpected MNIST label range'

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        # Bound the loop up front so that exactly max_idx items are produced,
        # including none at all when max_idx is 0 (yield-then-check would
        # always emit the first image).
        for idx in range(max_idx):
            yield dict(img=images[idx], label=int(labels[idx]))

    return max_idx, iterate_images()
196
+
197
+ #----------------------------------------------------------------------------
198
+
199
  def make_transform(
200
  transform: Optional[str],
201
  output_width: Optional[int],
 
256
  else:
257
  return open_image_folder(source, max_images=max_images)
258
  elif os.path.isfile(source):
259
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
260
  return open_cifar10(source, max_images=max_images)
261
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
262
+ return open_mnist(source, max_images=max_images)
263
+ elif file_ext(source) == 'zip':
264
  return open_image_zip(source, max_images=max_images)
265
  else:
266
  assert False, 'unknown archive type'
 
325
  The input dataset format is guessed from the --source argument:
326
 
327
  \b
328
+ --source *_lmdb/ Load LSUN dataset
329
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
330
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
331
+ --source path/ Recursively load all images from path/
332
+ --source dataset.zip Recursively load all images from dataset.zip
333
 
334
+ The output dataset format can be either an image folder or a zip archive.
335
+ Specifying the output format and path:
336
 
337
  \b
338
+ --dest /path/to/dir Save output files under /path/to/dir
339
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
340
 
341
  Images within the dataset archive will be stored as uncompressed PNG.
342