💬 [Update] Bbox drawing tools, change label text
Browse files
- yolo/config/config.py +1 -0
- yolo/config/general.yaml +1 -0
- yolo/tools/data_augmentation.py +1 -1
- yolo/tools/solver.py +7 -1
- yolo/utils/bounding_box_utils.py +2 -2
yolo/config/config.py
CHANGED
@@ -140,6 +140,7 @@ class Config:
|
|
140 |
cpu_num: int
|
141 |
|
142 |
class_num: int
|
|
|
143 |
image_size: List[int]
|
144 |
|
145 |
out_path: str
|
|
|
140 |
cpu_num: int
|
141 |
|
142 |
class_num: int
|
143 |
+
class_list: List[str]
|
144 |
image_size: List[int]
|
145 |
|
146 |
out_path: str
|
yolo/config/general.yaml
CHANGED
@@ -2,6 +2,7 @@ device: 0
|
|
2 |
cpu_num: 16
|
3 |
|
4 |
class_num: 80
|
|
|
5 |
image_size: [640, 640]
|
6 |
|
7 |
out_path: runs
|
|
|
2 |
cpu_num: 16
|
3 |
|
4 |
class_num: 80
|
5 |
+
class_list: ['Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic light', 'Fire hydrant', 'Stop sign', 'Parking meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'Backpack', 'Umbrella', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports ball', 'Kite', 'Baseball bat', 'Baseball glove', 'Skateboard', 'Surfboard', 'Tennis racket', 'Bottle', 'Wine glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted plant', 'Bed', 'Dining table', 'Toilet', 'Tv', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy bear', 'Hair drier', 'Toothbrush']
|
6 |
image_size: [640, 640]
|
7 |
|
8 |
out_path: runs
|
yolo/tools/data_augmentation.py
CHANGED
@@ -9,7 +9,7 @@ class AugmentationComposer:
|
|
9 |
|
10 |
def __init__(self, transforms, image_size: int = 640):
|
11 |
self.transforms = transforms
|
12 |
-
self.image_size = image_size
|
13 |
self.pad_resize = PadAndResize(self.image_size)
|
14 |
|
15 |
for transform in self.transforms:
|
|
|
9 |
|
10 |
def __init__(self, transforms, image_size: int = 640):
|
11 |
self.transforms = transforms
|
12 |
+
self.image_size = image_size[0]
|
13 |
self.pad_resize = PadAndResize(self.image_size)
|
14 |
|
15 |
for transform in self.transforms:
|
yolo/tools/solver.py
CHANGED
@@ -106,6 +106,7 @@ class ModelTester:
|
|
106 |
|
107 |
self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
|
108 |
self.nms = cfg.task.nms
|
|
|
109 |
self.save_path = save_path
|
110 |
|
111 |
def solve(self, dataloader: StreamDataLoader):
|
@@ -119,7 +120,12 @@ class ModelTester:
|
|
119 |
predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
|
120 |
nms_out = bbox_nms(predict, self.nms)
|
121 |
draw_bboxes(
|
122 |
-
images[0],
|
|
|
|
|
|
|
|
|
|
|
123 |
)
|
124 |
except (KeyboardInterrupt, Exception) as e:
|
125 |
dataloader.stop_event.set()
|
|
|
106 |
|
107 |
self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
|
108 |
self.nms = cfg.task.nms
|
109 |
+
self.idx2label = cfg.class_list
|
110 |
self.save_path = save_path
|
111 |
|
112 |
def solve(self, dataloader: StreamDataLoader):
|
|
|
120 |
predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
|
121 |
nms_out = bbox_nms(predict, self.nms)
|
122 |
draw_bboxes(
|
123 |
+
images[0],
|
124 |
+
nms_out[0],
|
125 |
+
scaled_bbox=False,
|
126 |
+
save_path=self.save_path,
|
127 |
+
save_name=f"frame{idx:03d}.png",
|
128 |
+
idx2label=self.idx2label,
|
129 |
)
|
130 |
except (KeyboardInterrupt, Exception) as e:
|
131 |
dataloader.stop_event.set()
|
yolo/utils/bounding_box_utils.py
CHANGED
@@ -307,7 +307,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
|
|
307 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
308 |
|
309 |
predict_nms = torch.cat(
|
310 |
-
[valid_cls[instance_idx][:, None], valid_con[instance_idx][:, None]], dim=-1
|
311 |
)
|
312 |
|
313 |
predicts_nms.append(predict_nms)
|
@@ -322,7 +322,7 @@ def calculate_map(predictions, ground_truths, iou_thresholds):
|
|
322 |
ground_truths = ground_truths[:n_gts]
|
323 |
aps = []
|
324 |
|
325 |
-
ious = calculate_iou(predictions[:, 1:], ground_truths[:, 1:])  # [n_preds, n_gts]
|
326 |
|
327 |
for threshold in iou_thresholds:
|
328 |
tp = torch.zeros(n_preds, device=device)
|
|
|
307 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
308 |
|
309 |
predict_nms = torch.cat(
|
310 |
+
[valid_cls[instance_idx][:, None], valid_box[instance_idx], valid_con[instance_idx][:, None]], dim=-1
|
311 |
)
|
312 |
|
313 |
predicts_nms.append(predict_nms)
|
|
|
322 |
ground_truths = ground_truths[:n_gts]
|
323 |
aps = []
|
324 |
|
325 |
+
ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:]) # [n_preds, n_gts]
|
326 |
|
327 |
for threshold in iou_thresholds:
|
328 |
tp = torch.zeros(n_preds, device=device)
|