glenn-jocher committed
Commit d61930e · 1 Parent(s): 0fda95a

Improved corruption handling during scan and cache (#999)

Files changed (1):
  utils/datasets.py  +19 -15
utils/datasets.py CHANGED
@@ -328,6 +328,12 @@ class LoadStreams:  # multiple IP or RTSP cameras
 class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
+
+        def img2label_paths(img_paths):
+            # Define label paths as a function of image paths
+            sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
+            return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]
+
         try:
             f = []  # image files
             for p in path if isinstance(path, list) else [path]:
@@ -362,11 +368,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.mosaic_border = [-img_size // 2, -img_size // 2]
         self.stride = stride

-        # Define labels
-        sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
-        self.label_files = [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in self.img_files]
-
         # Check cache
+        self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
         if os.path.isfile(cache_path):
             cache = torch.load(cache_path)  # load
@@ -375,12 +378,15 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         else:
             cache = self.cache_labels(cache_path)  # cache

-        # Get labels
-        labels, shapes = zip(*[cache[x] for x in self.img_files])
-        self.shapes = np.array(shapes, dtype=np.float64)
+        # Read cache
+        cache.pop('hash')  # remove hash
+        labels, shapes = zip(*cache.values())
         self.labels = list(labels)
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.img_files = list(cache.keys())  # update
+        self.label_files = img2label_paths(cache.keys())  # update

-        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
+        # Rectangular Training
         if self.rect:
             # Sort by aspect ratio
             s = self.shapes  # wh
@@ -404,7 +410,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing

             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

-        # Cache labels
+        # Check labels
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
         pbar = enumerate(self.label_files)
@@ -483,10 +489,9 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         for (img, label) in pbar:
             try:
                 l = []
-                image = Image.open(img)
-                image.verify()  # PIL verify
-                # _ = io.imread(img)  # skimage verify (from skimage import io)
-                shape = exif_size(image)  # image size
+                im = Image.open(img)
+                im.verify()  # PIL verify
+                shape = exif_size(im)  # image size
                 assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                 if os.path.isfile(label):
                     with open(label, 'r') as f:
@@ -495,8 +500,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 l = np.zeros((0, 5), dtype=np.float32)
                 x[img] = [l, shape]
             except Exception as e:
-                x[img] = [None, None]
-                print('WARNING: %s: %s' % (img, e))
+                print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))

         x['hash'] = get_hash(self.label_files + self.img_files)
         torch.save(x, path)  # save for next time
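
In short, the commit replaces the inline /images/ to /labels/ substitution with a local img2label_paths helper, so label paths can be rebuilt later from whichever image paths survive the cache. A minimal sketch of what the helper does, using a hypothetical COCO-style path that is not from the repository:

import os

def img2label_paths(img_paths):
    # Swap the first /images/ path component for /labels/ and change the extension to .txt
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep
    return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]

# Hypothetical example path:
print(img2label_paths(['/data/coco/images/train2017/000000000009.jpg']))
# ['/data/coco/labels/train2017/000000000009.txt']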
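
The other half of the change is in how the cache is read back. Because cache_labels() now warns and skips a corrupted file instead of storing a [None, None] placeholder, rebuilding self.img_files and self.label_files from cache.keys() drops corrupted images from training. A minimal sketch of that read path, with a hypothetical two-entry cache dict of the shape written by cache_labels():

import numpy as np

# Hypothetical cache: {img_path: [labels, (w, h)], ..., 'hash': ...}
cache = {
    '/data/images/train/a.jpg': [np.zeros((0, 5), dtype=np.float32), (640, 480)],
    '/data/images/train/b.jpg': [np.array([[0, .5, .5, .2, .2]], dtype=np.float32), (1280, 720)],
    'hash': 12345.0,
}

cache.pop('hash')                      # remove hash before unpacking
labels, shapes = zip(*cache.values())  # per-image label arrays and (w, h) shapes
img_files = list(cache.keys())         # only images that survived the scan
shapes = np.array(shapes, dtype=np.float64)
print(len(img_files), shapes.shape)    # 2 (2, 2)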
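
The corruption check itself is the PIL open/verify pattern visible in the last two hunks: any exception raised while opening, verifying, or parsing a label triggers the warning and the entry is skipped. A standalone sketch of that check, using im.size in place of the repository's exif_size() helper (which also accounts for EXIF rotation):

from PIL import Image

def is_valid_image(path):
    # True if PIL can verify the file and it is at least 10x10 pixels
    try:
        im = Image.open(path)
        im.verify()  # PIL verify; raises on truncated or corrupted files
        w, h = im.size
        assert (w > 9) and (h > 9), 'image size <10 pixels'
        return True
    except Exception as e:
        print('WARNING: Ignoring corrupted image and/or label %s: %s' % (path, e))
        return False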