glenn-jocher committed on
Commit
a346926
·
unverified ·
1 Parent(s): b754525

Add class filtering to `LoadImagesAndLabels()` dataloader (#5172)

Browse files

* Add train class filter feature to datasets.py

Allows training on a subset of the total classes when the `include_class` list is defined in datasets.py (line 448):
```python
include_class = [] # filter labels to include only these classes (optional)
```

* segments fix

Files changed (1) hide show
  1. utils/datasets.py +14 -4
utils/datasets.py CHANGED
@@ -437,10 +437,6 @@ class LoadImagesAndLabels(Dataset):
437
  self.shapes = np.array(shapes, dtype=np.float64)
438
  self.img_files = list(cache.keys()) # update
439
  self.label_files = img2label_paths(cache.keys()) # update
440
- if single_cls:
441
- for x in self.labels:
442
- x[:, 0] = 0
443
-
444
  n = len(shapes) # number of images
445
  bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
446
  nb = bi[-1] + 1 # number of batches
@@ -448,6 +444,20 @@ class LoadImagesAndLabels(Dataset):
448
  self.n = n
449
  self.indices = range(n)
450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  # Rectangular Training
452
  if self.rect:
453
  # Sort by aspect ratio
 
437
  self.shapes = np.array(shapes, dtype=np.float64)
438
  self.img_files = list(cache.keys()) # update
439
  self.label_files = img2label_paths(cache.keys()) # update
 
 
 
 
440
  n = len(shapes) # number of images
441
  bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
442
  nb = bi[-1] + 1 # number of batches
 
444
  self.n = n
445
  self.indices = range(n)
446
 
447
+ # Update labels
448
+ include_class = [] # filter labels to include only these classes (optional)
449
+ include_class_array = np.array(include_class).reshape(1, -1)
450
+ for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
451
+ if include_class:
452
+ j = (label[:, 0:1] == include_class_array).any(1)
453
+ self.labels[i] = label[j]
454
+ if segment:
455
+ self.segments[i] = segment[j]
456
+ if single_cls: # single-class training, merge all classes into 0
457
+ self.labels[i][:, 0] = 0
458
+ if segment:
459
+ self.segments[i][:, 0] = 0
460
+
461
  # Rectangular Training
462
  if self.rect:
463
  # Sort by aspect ratio