henry000 commited on
Commit
597f02f
Β·
2 Parent(s): 306fc38 7967aab

πŸ”€ [Merge] branch 'SETUP' into INFERENCE

Browse files
yolo/tools/data_augmentation.py CHANGED
@@ -10,6 +10,7 @@ class AugmentationComposer:
10
  def __init__(self, transforms, image_size: int = 640):
11
  self.transforms = transforms
12
  self.image_size = image_size
 
13
 
14
  for transform in self.transforms:
15
  if hasattr(transform, "set_parent"):
@@ -18,9 +19,33 @@ class AugmentationComposer:
18
  def __call__(self, image, boxes):
19
  for transform in self.transforms:
20
  image, boxes = transform(image, boxes)
 
21
  return image, boxes
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class HorizontalFlip:
25
  """Randomly horizontally flips the image along with the bounding boxes."""
26
 
 
10
  def __init__(self, transforms, image_size: int = 640):
11
  self.transforms = transforms
12
  self.image_size = image_size
13
+ self.pad_resize = PadAndResize(self.image_size)
14
 
15
  for transform in self.transforms:
16
  if hasattr(transform, "set_parent"):
 
19
  def __call__(self, image, boxes):
20
  for transform in self.transforms:
21
  image, boxes = transform(image, boxes)
22
+ image, boxes = self.pad_resize(image, boxes)
23
  return image, boxes
24
 
25
 
26
+ class PadAndResize:
27
+ def __init__(self, image_size):
28
+ """Initialize the object with the target image size."""
29
+ self.image_size = image_size
30
+
31
+ def __call__(self, image, boxes):
32
+ original_size = max(image.size)
33
+ scale = self.image_size / original_size
34
+ square_img = Image.new("RGB", (original_size, original_size), (255, 255, 255))
35
+ left = (original_size - image.width) // 2
36
+ top = (original_size - image.height) // 2
37
+ square_img.paste(image, (left, top))
38
+
39
+ resized_img = square_img.resize((self.image_size, self.image_size))
40
+
41
+ boxes[:, 1] = (boxes[:, 1] + left) * scale
42
+ boxes[:, 2] = (boxes[:, 2] + top) * scale
43
+ boxes[:, 3] = (boxes[:, 3] + left) * scale
44
+ boxes[:, 4] = (boxes[:, 4] + top) * scale
45
+
46
+ return resized_img, boxes
47
+
48
+
49
  class HorizontalFlip:
50
  """Randomly horizontally flips the image along with the bounding boxes."""
51
 
yolo/tools/data_loader.py CHANGED
@@ -32,8 +32,7 @@ from yolo.utils.dataset_utils import (
32
  class YoloDataset(Dataset):
33
  def __init__(self, config: TrainConfig, phase: str = "train2017", image_size: int = 640):
34
  augment_cfg = config.data.data_augment
35
- # TODO: add yaml -> train: train2017
36
- phase_name = config.dataset.auto_download.get(phase, phase)
37
  self.image_size = image_size
38
 
39
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
@@ -102,13 +101,14 @@ class YoloDataset(Dataset):
102
  continue
103
  with open(label_path, "r") as file:
104
  image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
 
 
105
 
106
  labels = self.load_valid_labels(image_id, image_seg_annotations)
107
- if labels is not None:
108
- img_path = path.join(images_path, image_name)
109
- data.append((img_path, labels))
110
- valid_inputs += 1
111
 
 
 
 
112
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
113
  return data
114
 
@@ -135,7 +135,7 @@ class YoloDataset(Dataset):
135
  return torch.stack(bboxes)
136
  else:
137
  logger.warning("No valid BBox in {}", label_path)
138
- return None
139
 
140
  def get_data(self, idx):
141
  img_path, bboxes = self.data[idx]
@@ -161,7 +161,7 @@ class YoloDataLoader(DataLoader):
161
  def __init__(self, config: Config):
162
  """Initializes the YoloDataLoader with hydra-config files."""
163
  data_cfg = config.task.data
164
- dataset = YoloDataset(config.task)
165
 
166
  super().__init__(
167
  dataset,
 
32
  class YoloDataset(Dataset):
33
  def __init__(self, config: TrainConfig, phase: str = "train2017", image_size: int = 640):
34
  augment_cfg = config.data.data_augment
35
+ phase_name = config.dataset.get(phase, phase)
 
36
  self.image_size = image_size
37
 
38
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
 
101
  continue
102
  with open(label_path, "r") as file:
103
  image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
104
+ else:
105
+ image_seg_annotations = []
106
 
107
  labels = self.load_valid_labels(image_id, image_seg_annotations)
 
 
 
 
108
 
109
+ img_path = path.join(images_path, image_name)
110
+ data.append((img_path, labels))
111
+ valid_inputs += 1
112
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
113
  return data
114
 
 
135
  return torch.stack(bboxes)
136
  else:
137
  logger.warning("No valid BBox in {}", label_path)
138
+ return torch.zeros((0, 5))
139
 
140
  def get_data(self, idx):
141
  img_path, bboxes = self.data[idx]
 
161
  def __init__(self, config: Config):
162
  """Initializes the YoloDataLoader with hydra-config files."""
163
  data_cfg = config.task.data
164
+ dataset = YoloDataset(config.task, config.task.task)
165
 
166
  super().__init__(
167
  dataset,
yolo/tools/drawer.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import List, Union
2
 
3
  import numpy as np
@@ -8,7 +9,11 @@ from torchvision.transforms.functional import to_pil_image
8
 
9
 
10
  def draw_bboxes(
11
- img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]], *, scaled_bbox: bool = True
 
 
 
 
12
  ):
13
  """
14
  Draw bounding boxes on an image.
@@ -21,7 +26,7 @@ def draw_bboxes(
21
  # Convert tensor image to PIL Image if necessary
22
  if isinstance(img, torch.Tensor):
23
  if img.dim() > 3:
24
- logger.info("Multi-frame tensor detected, using the first image.")
25
  img = img[0]
26
  bboxes = bboxes[0]
27
  img = to_pil_image(img)
@@ -41,8 +46,9 @@ def draw_bboxes(
41
  draw.rectangle(shape, outline="red", width=3)
42
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
43
 
44
- img.save("visualize.jpg") # Save the image with annotations
45
- logger.info("Saved visualize image at visualize.png")
 
46
  return img
47
 
48
 
 
1
+ import os
2
  from typing import List, Union
3
 
4
  import numpy as np
 
9
 
10
 
11
  def draw_bboxes(
12
+ img: Union[Image.Image, torch.Tensor],
13
+ bboxes: List[List[Union[int, float]]],
14
+ *,
15
+ scaled_bbox: bool = True,
16
+ save_path: str = "",
17
  ):
18
  """
19
  Draw bounding boxes on an image.
 
26
  # Convert tensor image to PIL Image if necessary
27
  if isinstance(img, torch.Tensor):
28
  if img.dim() > 3:
29
+ logger.warning("πŸ” Multi-frame tensor detected, using the first image.")
30
  img = img[0]
31
  bboxes = bboxes[0]
32
  img = to_pil_image(img)
 
46
  draw.rectangle(shape, outline="red", width=3)
47
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
48
 
49
+ save_image_path = os.path.join(save_path, "visualize.png")
50
+ img.save(save_image_path) # Save the image with annotations
51
+ logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
52
  return img
53
 
54
 
yolo/utils/dataset_utils.py CHANGED
@@ -5,6 +5,7 @@ from os import path
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
  import numpy as np
 
8
 
9
  from yolo.tools.data_conversion import discretize_categories
10
 
@@ -32,7 +33,8 @@ def locate_label_paths(dataset_path: str, phase_name: str):
32
  if txt_files:
33
  return txt_labels_path, "txt"
34
 
35
- raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
 
36
 
37
 
38
  def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
 
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
  import numpy as np
8
+ from loguru import logger
9
 
10
  from yolo.tools.data_conversion import discretize_categories
11
 
 
33
  if txt_files:
34
  return txt_labels_path, "txt"
35
 
36
+ logger.warning("No labels found in the specified dataset path and phase name.")
37
+ return [], None
38
 
39
 
40
  def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]: