Commit d61930e (parent: 0fda95a)
Improved corruption handling during scan and cache (#999)

utils/datasets.py  CHANGED  (+19 -15)
@@ -328,6 +328,12 @@ class LoadStreams:  # multiple IP or RTSP cameras
 class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
+
+        def img2label_paths(img_paths):
+            # Define label paths as a function of image paths
+            sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
+            return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]
+
         try:
             f = []  # image files
             for p in path if isinstance(path, list) else [path]:
@@ -362,11 +368,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.mosaic_border = [-img_size // 2, -img_size // 2]
         self.stride = stride
 
-        # Define labels
-        sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
-        self.label_files = [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in self.img_files]
-
         # Check cache
+        self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
         if os.path.isfile(cache_path):
             cache = torch.load(cache_path)  # load
@@ -375,12 +378,15 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         else:
             cache = self.cache_labels(cache_path)  # cache
 
-        # Get labels
-        labels, shapes = zip(*[cache[x] for x in self.img_files])
-        self.shapes = np.array(shapes, dtype=np.float64)
+        # Read cache
+        cache.pop('hash')  # remove hash
+        labels, shapes = zip(*cache.values())
         self.labels = list(labels)
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.img_files = list(cache.keys())  # update
+        self.label_files = img2label_paths(cache.keys())  # update
 
-        # Rectangular Training
+        # Rectangular Training
         if self.rect:
             # Sort by aspect ratio
             s = self.shapes  # wh
@@ -404,7 +410,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
 
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
 
-        # Cache labels
+        # Check labels
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
         pbar = enumerate(self.label_files)
@@ -483,10 +489,9 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         for (img, label) in pbar:
             try:
                 l = []
-                image = Image.open(img)
-                image.verify()  # PIL verify
-                # _ = io.imread(img)  # skimage verify (from skimage import io)
-                shape = exif_size(image)  # image size
+                im = Image.open(img)
+                im.verify()  # PIL verify
+                shape = exif_size(im)  # image size
                 assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                 if os.path.isfile(label):
                     with open(label, 'r') as f:
@@ -495,8 +500,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     l = np.zeros((0, 5), dtype=np.float32)
                 x[img] = [l, shape]
             except Exception as e:
-                x[img] = [None, None]
-                print('WARNING: %s: %s' % (img, e))
+                print('WARNING: Ignoring corrupted image and/or label:%s: %s' % (img, e))
 
         x['hash'] = get_hash(self.label_files + self.img_files)
         torch.save(x, path)  # save for next time
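For context, the corruption check this commit introduces follows a standard PIL pattern: open each image, call verify() to detect truncated or unreadable files, read the (EXIF-corrected) size, and skip the file entirely on failure instead of caching a [None, None] placeholder. The sketch below is a minimal, self-contained illustration of that scan-and-read-back flow, not the repo's actual implementation: scan_dataset is a hypothetical stand-in for cache_labels, im.size stands in for the repo's exif_size helper, and the coco128 path is only an example.

import os
from pathlib import Path

import numpy as np
from PIL import Image


def img2label_paths(img_paths):
    # Same idea as the helper added in this commit: swap the last /images/
    # path component for /labels/ and the file suffix for .txt.
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep
    return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]


def scan_dataset(img_files):
    # Hypothetical stand-in for cache_labels(): returns {img: [labels, shape]}
    # with corrupted images/labels skipped rather than stored as [None, None].
    label_files = img2label_paths(img_files)
    cache = {}
    for img, label in zip(img_files, label_files):
        try:
            im = Image.open(img)
            im.verify()  # PIL verify: raises on truncated/corrupted files
            shape = im.size  # (width, height); the repo corrects this via EXIF orientation
            assert shape[0] > 9 and shape[1] > 9, 'image size <10 pixels'
            l = np.zeros((0, 5), dtype=np.float32)
            if os.path.isfile(label):
                with open(label, 'r') as f:
                    lines = [x.split() for x in f.read().splitlines() if x]
                if lines:
                    l = np.array(lines, dtype=np.float32)
            cache[img] = [l, shape]
        except Exception as e:
            print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))
    return cache


if __name__ == '__main__':
    imgs = sorted(str(p) for p in Path('coco128/images/train2017').glob('*.jpg'))
    cache = scan_dataset(imgs)
    # Read back: rebuilding the lists from the cache keys drops corrupted files
    # from both the image and label lists, mirroring the __init__ change above.
    img_files = list(cache.keys())
    label_files = img2label_paths(img_files)
    print('%g of %g images usable' % (len(img_files), len(imgs)))

Rebuilding img_files and label_files from the cache keys is what makes the scan results stick: files that failed verification never re-enter the dataset, so downstream loading no longer has to guard against None entries.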