Add class filtering to `LoadImagesAndLabels()` dataloader (#5172)
Browse files* Add train class filter feature to datasets.py
Allows for training on a subset of total classes if `include_class` list is defined on datasets.py L448:
```python
include_class = [] # filter labels to include only these classes (optional)
```
* segments fix
- utils/datasets.py +14 -4
utils/datasets.py
CHANGED
@@ -437,10 +437,6 @@ class LoadImagesAndLabels(Dataset):
|
|
437 |
self.shapes = np.array(shapes, dtype=np.float64)
|
438 |
self.img_files = list(cache.keys()) # update
|
439 |
self.label_files = img2label_paths(cache.keys()) # update
|
440 |
-
if single_cls:
|
441 |
-
for x in self.labels:
|
442 |
-
x[:, 0] = 0
|
443 |
-
|
444 |
n = len(shapes) # number of images
|
445 |
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
|
446 |
nb = bi[-1] + 1 # number of batches
|
@@ -448,6 +444,20 @@ class LoadImagesAndLabels(Dataset):
|
|
448 |
self.n = n
|
449 |
self.indices = range(n)
|
450 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
# Rectangular Training
|
452 |
if self.rect:
|
453 |
# Sort by aspect ratio
|
|
|
437 |
self.shapes = np.array(shapes, dtype=np.float64)
|
438 |
self.img_files = list(cache.keys()) # update
|
439 |
self.label_files = img2label_paths(cache.keys()) # update
|
|
|
|
|
|
|
|
|
440 |
n = len(shapes) # number of images
|
441 |
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
|
442 |
nb = bi[-1] + 1 # number of batches
|
|
|
444 |
self.n = n
|
445 |
self.indices = range(n)
|
446 |
|
447 |
+
# Update labels
|
448 |
+
include_class = [] # filter labels to include only these classes (optional)
|
449 |
+
include_class_array = np.array(include_class).reshape(1, -1)
|
450 |
+
for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
|
451 |
+
if include_class:
|
452 |
+
j = (label[:, 0:1] == include_class_array).any(1)
|
453 |
+
self.labels[i] = label[j]
|
454 |
+
if segment:
|
455 |
+
self.segments[i] = segment[j]
|
456 |
+
if single_cls: # single-class training, merge all classes into 0
|
457 |
+
self.labels[i][:, 0] = 0
|
458 |
+
if segment:
|
459 |
+
self.segments[i][:, 0] = 0
|
460 |
+
|
461 |
# Rectangular Training
|
462 |
if self.rect:
|
463 |
# Sort by aspect ratio
|