henry000 commited on
Commit
64acfd1
Β·
1 Parent(s): c8710f3

🎨 [Update] drawer input, more robust/consistency

Browse files
Files changed (2) hide show
  1. yolo/tools/drawer.py +6 -4
  2. yolo/tools/solver.py +1 -1
yolo/tools/drawer.py CHANGED
@@ -27,15 +27,17 @@ def draw_bboxes(
27
  if isinstance(img, torch.Tensor):
28
  if img.dim() > 3:
29
  logger.warning("πŸ” >3 dimension tensor detected, using the 0-idx image.")
30
- img, bboxes = img[0], bboxes[0]
31
  img = to_pil_image(img)
32
 
 
 
33
  draw = ImageDraw.Draw(img, "RGBA")
34
 
35
  try:
36
- font = ImageFont.truetype("arial.ttf", 15)
37
  except IOError:
38
- font = ImageFont.load_default()
39
 
40
  for bbox in bboxes:
41
  class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
@@ -52,7 +54,7 @@ def draw_bboxes(
52
 
53
  text_bbox = font.getbbox(label_text)
54
  text_width = text_bbox[2] - text_bbox[0]
55
- text_height = (text_bbox[3] - text_bbox[1]) * 1.25
56
 
57
  text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)]
58
  draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
 
27
  if isinstance(img, torch.Tensor):
28
  if img.dim() > 3:
29
  logger.warning("πŸ” >3 dimension tensor detected, using the 0-idx image.")
30
+ img = img[0]
31
  img = to_pil_image(img)
32
 
33
+ img, bboxes = img.copy(), bboxes[0]
34
+ label_size = img.size[1] / 30
35
  draw = ImageDraw.Draw(img, "RGBA")
36
 
37
  try:
38
+ font = ImageFont.truetype("arial.ttf", label_size)
39
  except IOError:
40
+ font = ImageFont.load_default(label_size)
41
 
42
  for bbox in bboxes:
43
  class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
 
54
 
55
  text_bbox = font.getbbox(label_text)
56
  text_width = text_bbox[2] - text_bbox[0]
57
+ text_height = (text_bbox[3] - text_bbox[1]) * 1.5
58
 
59
  text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)]
60
  draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
yolo/tools/solver.py CHANGED
@@ -142,7 +142,7 @@ class ModelTester:
142
  predicts = self.model(images)
143
  predicts = self.vec2box(predicts["Main"])
144
  nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
145
- img = draw_bboxes(images[0], nms_out[0], idx2label=self.idx2label)
146
 
147
  if dataloader.is_stream:
148
  img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
 
142
  predicts = self.model(images)
143
  predicts = self.vec2box(predicts["Main"])
144
  nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
145
+ img = draw_bboxes(images, nms_out, idx2label=self.idx2label)
146
 
147
  if dataloader.is_stream:
148
  img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)