lucytuan committed
Commit 4ebaf9d · Parents: e3d53d5, 23db031

Merge branch 'DATASET' of https://github.com/WongKinYiu/yolov9mit into DATASET

.github/workflows/main.yaml ADDED
@@ -0,0 +1,26 @@
+name: YOLOv9 - Model test
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+      - name: Test with pytest
+        run: |
+          pytest
.pre-commit-config.yaml CHANGED
@@ -6,3 +6,9 @@ repos:
         language_version: python3 # Specify the Python version
         exclude: '.*\.yaml$' # Regex pattern to exclude all YAML files
         args: ["--line-length", "120"] # Set max line length to 100 characters
+
+  - repo: https://github.com/pre-commit/mirrors-isort
+    rev: v5.10.1 # Use the appropriate version or "stable" for the latest stable release
+    hooks:
+      - id: isort
+        args: ["--profile", "black", "--verbose"]
README.md CHANGED
@@ -1,6 +1,20 @@
 # YOLOv9-MIT
 An MIT license rewrite of YOLOv9
 
+![WIP](https://img.shields.io/badge/status-WIP-orange)
+> [!IMPORTANT]
+> This project is currently a Work In Progress and may undergo significant changes. It is not recommended for use in production environments until further notice. Please check back regularly for updates.
+>
+> Use of this code is at your own risk and discretion. It is advisable to consult with the project owner before deploying or integrating into any critical systems.
+
+## Contributing
+
+While the project's structure is still being finalized, we ask that potential contributors wait for these foundational decisions to be made. We greatly appreciate your patience and are excited to welcome contributions from the community once we are ready. Alternatively, you are welcome to propose functions that should be implemented based on the original YOLO version or suggest other enhancements!
+
+If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
+
+
+
 ## To-Do Lists
 - [ ] Project Setup
 - [X] requirements
@@ -16,7 +30,7 @@ An MIT license rewrite of YOLOv9
 - [ ] Auto Download
 - [ ] xywh, xxyy, xcyc
 - [ ] Dataloder
-- [ ] Data arugment
+- [ ] Data augment
 - [ ] Model
 - [ ] load model
 - [ ] from yaml
config/config.py CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Dict, Union
+from typing import Dict, List, Union
 
 
 @dataclass
config/config.yaml CHANGED
@@ -7,4 +7,5 @@ defaults:
   - download: ../data/download
   - augmentation: ../data/augmentation
   - model: v7-base
+  - hyper: default
   - _self_
config/data/augmentation.yaml CHANGED
@@ -1,3 +1,3 @@
 Mosaic: 1
-MixUp: 1
-RandomHorizontalFlip: 0.5
+# MixUp: 1
+HorizontalFlip: 0.5
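
Note: these keys are resolved by name into transform classes when the dataset is built (see utils/dataloader.py below), so `HorizontalFlip` here must match the class renamed in utils/data_augment.py. A minimal sketch of that resolution, assuming the parsed YAML is a plain mapping:

    # Hypothetical illustration of how the YAML keys become transform objects,
    # mirroring the eval(...) loop added in utils/dataloader.py.
    from utils.data_augment import Compose, HorizontalFlip, Mosaic

    augment_cfg = {"Mosaic": 1, "HorizontalFlip": 0.5}  # parsed augmentation.yaml
    transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
    transform = Compose(transforms, image_size=640)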
config/hyper/default.yaml ADDED
@@ -0,0 +1,5 @@
+data:
+  batch_size: 4
+  shuffle: True
+  num_workers: 4
+  pin_memory: True
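
Note: the new YoloDataLoader in utils/dataloader.py forwards these fields unchanged to torch.utils.data.DataLoader. A sketch of that pass-through, assuming `cfg` is the composed Hydra config and using a stand-in dataset:

    # Sketch: config.hyper.data maps one-to-one onto DataLoader kwargs.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    hyper = cfg.hyper.data                                  # cfg assumed composed by Hydra
    dataset = TensorDataset(torch.zeros(16, 3, 640, 640))   # stand-in dataset
    loader = DataLoader(
        dataset,
        batch_size=hyper.batch_size,    # 4
        shuffle=hyper.shuffle,          # True
        num_workers=hyper.num_workers,  # 4
        pin_memory=hyper.pin_memory,    # True
    )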
config/model/v7-base.yaml CHANGED
@@ -241,3 +241,4 @@ model:
       - [36,75, 76,55, 72,146] # P4/16
       - [142,110, 192,243, 459,401] # P5/32
     source: [102, 103, 104]
+    output: True
model/module.py CHANGED
@@ -11,10 +11,10 @@ class Conv(nn.Module):
         out_channels,
         kernel_size,
         stride=1,
-        padding=0,
+        padding=None,
         dilation=1,
         groups=1,
-        act=nn.ReLU(),
+        act=nn.SiLU(),
         bias=False,
         auto_padding=True,
         padding_mode="zeros",
@@ -48,10 +48,12 @@ class Conv(nn.Module):
 # RepVGG
 class RepConv(nn.Module):
     # https://github.com/DingXiaoH/RepVGG
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, act=nn.ReLU()):
+    def __init__(
+        self, in_channels, out_channels, kernel_size=3, padding=None, stride=1, groups=1, act=nn.SiLU(), deploy=False
+    ):
 
         super().__init__()
-
+        self.deploy = deploy
         self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, groups=groups, act=False)
         self.conv2 = Conv(in_channels, out_channels, 1, stride, groups=groups, act=False)
         self.act = act if isinstance(act, nn.Module) else nn.Identity()
@@ -64,6 +66,30 @@ class RepConv(nn.Module):
 
     # to be implement
     # def fuse_convs(self):
+    def fuse_conv_bn(self, conv, bn):
+
+        std = (bn.running_var + bn.eps).sqrt()
+        bias = bn.bias - bn.running_mean * bn.weight / std
+
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        weights = conv.weight * t
+
+        bn = nn.Identity()
+        conv = nn.Conv2d(
+            in_channels=conv.in_channels,
+            out_channels=conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            dilation=conv.dilation,
+            groups=conv.groups,
+            bias=True,
+            padding_mode=conv.padding_mode,
+        )
+
+        conv.weight = torch.nn.Parameter(weights)
+        conv.bias = torch.nn.Parameter(bias)
+        return conv
 
 
 # ResNet
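
Note: fuse_conv_bn folds a BatchNorm's affine transform into the preceding convolution: with running statistics (μ, σ²) and affine parameters (γ, β), the fused weights are W' = W·γ/σ and the fused bias is b' = β − μ·γ/σ. A quick self-contained equivalence check of that math (a sketch, not part of the commit):

    # Verify BN folding: bn(conv(x)) == fused(x) with
    # W' = W * gamma / sigma, b' = beta - mu * gamma / sigma
    import torch
    import torch.nn as nn

    conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16).eval()
    with torch.no_grad():
        bn.running_mean.uniform_(-1, 1)   # pretend the BN was trained
        bn.running_var.uniform_(0.5, 2.0)

    std = (bn.running_var + bn.eps).sqrt()
    fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
    with torch.no_grad():
        fused.weight.copy_(conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1))
        fused.bias.copy_(bn.bias - bn.running_mean * bn.weight / std)

    x = torch.randn(1, 8, 32, 32)
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)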
model/yolo.py CHANGED
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 from loguru import logger
 from omegaconf import OmegaConf
+
 from tools.layer_helper import get_layer_map
 
 
@@ -32,6 +33,7 @@ class YOLO(nn.Module):
             layer_type, layer_info = next(iter(layer_spec.items()))
             layer_args = layer_info.get("args", {})
             source = layer_info.get("source", -1)
+            output = layer_info.get("output", False)
 
             if isinstance(source, str):
                 source = layer_indices_by_tag[source]
@@ -41,7 +43,7 @@ class YOLO(nn.Module):
                 layer_args["nc"] = self.nc
                 layer_args["ch"] = [output_dim[idx] for idx in source]
 
-            layer = self.create_layer(layer_type, source, **layer_args)
+            layer = self.create_layer(layer_type, source, output, **layer_args)
             model_list.append(layer)
 
             if "tags" in layer_info:
@@ -55,6 +57,7 @@ class YOLO(nn.Module):
 
     def forward(self, x):
         y = [x]
+        output = []
         for layer in self.model:
             if OmegaConf.is_list(layer.source):
                 model_input = [y[idx] for idx in layer.source]
@@ -62,7 +65,9 @@ class YOLO(nn.Module):
                 model_input = y[layer.source]
             x = layer(model_input)
             y.append(x)
-        return x
+            if layer.output:
+                output.append(x)
+        return output
 
     def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
         if "Conv" in layer_type:
@@ -74,10 +79,11 @@ class YOLO(nn.Module):
         if layer_type == "IDetect":
             return None
 
-    def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
+    def create_layer(self, layer_type: str, source: Union[int, list], output=False, **kwargs):
         if layer_type in self.layer_map:
             layer = self.layer_map[layer_type](**kwargs)
             layer.source = source
+            layer.output = output
             return layer
         else:
             raise ValueError(f"Unsupported layer type: {layer_type}")
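
Note: forward no longer returns only the last layer's tensor; it collects every layer whose spec sets `output: True` (e.g. the detection head tagged in config/model/v7-base.yaml above) and returns them as a list. A usage sketch, assuming a Hydra-composed model config:

    # Sketch: the model now yields a list with one entry per output-tagged layer.
    import torch
    from model.yolo import get_model

    model = get_model(model_cfg)               # model_cfg assumed composed by Hydra
    outputs = model(torch.rand(2, 3, 640, 640))
    head_out = outputs[-1]                     # the layer marked output: True (detection head)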
requirements.txt CHANGED
@@ -1,11 +1,13 @@
 hydra-core
 loguru
 numpy
+Pillow
 pytest
 pyyaml
 requests
 rich
 torch
+torchvision
 tqdm
 Pillow
 diskcache
tests/test_model/test_yolo.py ADDED
@@ -0,0 +1,48 @@
+import sys
+
+import pytest
+import torch
+from hydra import compose, initialize
+from hydra.core.global_hydra import GlobalHydra
+from omegaconf import DictConfig, OmegaConf
+
+sys.path.append("./")
+from model.yolo import YOLO, get_model
+
+config_path = "../../config/model"
+config_name = "v7-base"
+
+
+def test_build_model():
+
+    with initialize(config_path=config_path, version_base=None):
+        model_cfg = compose(config_name=config_name)
+        OmegaConf.set_struct(model_cfg, False)
+        model = YOLO(model_cfg)
+        model.build_model(model_cfg.model)
+        assert len(model.model) == 106
+
+
+def test_get_model():
+    with initialize(config_path=config_path, version_base=None):
+        model_cfg = compose(config_name=config_name)
+        model = get_model(model_cfg)
+        assert isinstance(model, YOLO)
+
+
+def test_yolo_forward_output_shape():
+    with initialize(config_path=config_path, version_base=None):
+        model_cfg = compose(config_name=config_name)
+
+        model = get_model(model_cfg)
+        # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
+        dummy_input = torch.rand(2, 3, 640, 640)
+
+        # Forward pass through the model
+        output = model(dummy_input)
+        output_shape = [x.shape for x in output[-1]]
+        assert output_shape == [
+            torch.Size([2, 3, 20, 20, 85]),
+            torch.Size([2, 3, 80, 80, 85]),
+            torch.Size([2, 3, 40, 40, 85]),
+        ]
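
Note: the expected shapes follow directly from the anchor layout: 3 anchors per scale, a 640/stride grid per scale, and 85 attributes per anchor, presumably 4 box coordinates + 1 objectness + 80 COCO classes. The arithmetic, spelled out:

    # Where the expected output shapes come from (assuming 80 COCO classes):
    batch, num_anchors = 2, 3
    num_attrib = 4 + 1 + 80   # box (4) + objectness (1) + classes (80) = 85
    for stride in (32, 8, 16):
        grid = 640 // stride  # 20, 80, 40
        print((batch, num_anchors, grid, grid, num_attrib))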
tests/test_utils/test_dataaugment.py ADDED
@@ -0,0 +1,64 @@
+import sys
+
+import pytest
+import torch
+from PIL import Image
+from torchvision.transforms import functional as TF
+
+sys.path.append("./")
+from utils.data_augment import Compose, Mosaic, RandomHorizontalFlip
+
+
+def test_random_horizontal_flip():
+    # Create a mock image and bounding boxes
+    img = Image.new("RGB", (100, 100), color="red")
+    boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])  # class, xmin, ymin, xmax, ymax
+
+    flip_transform = RandomHorizontalFlip(prob=1)  # Set probability to 1 to ensure flip
+    flipped_img, flipped_boxes = flip_transform(img, boxes)
+
+    # Assert image is flipped by comparing it to a manually flipped image
+    assert TF.hflip(img) == flipped_img
+
+    # Assert bounding boxes are flipped correctly
+    expected_boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])
+    expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
+    assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
+
+
+def test_compose():
+    # Define two mock transforms that simply return the inputs
+    def mock_transform(image, boxes):
+        return image, boxes
+
+    compose = Compose([mock_transform, mock_transform])
+    img = Image.new("RGB", (10, 10), color="blue")
+    boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
+
+    transformed_img, transformed_boxes = compose(img, boxes)
+
+    assert transformed_img == img, "Image should not be altered"
+    assert torch.equal(transformed_boxes, boxes), "Boxes should not be altered"
+
+
+def test_mosaic():
+    img = Image.new("RGB", (100, 100), color="green")
+    boxes = torch.tensor([[0, 0.25, 0.25, 0.75, 0.75]])
+
+    # Mock parent with image_size and get_more_data method
+    class MockParent:
+        image_size = 100
+
+        def get_more_data(self, num_images):
+            return [(img, boxes) for _ in range(num_images)]
+
+    mosaic = Mosaic(prob=1)  # Ensure mosaic is applied
+    mosaic.set_parent(MockParent())
+
+    mosaic_img, mosaic_boxes = mosaic(img, boxes)
+
+    # Checks here would depend on the exact expected behavior of the mosaic function,
+    # such as dimensions and content of the output image and boxes.
+
+    assert mosaic_img.size == (200, 200), "Mosaic image size should be doubled"
+    assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
tools/layer_helper.py CHANGED
@@ -1,5 +1,7 @@
 import inspect
+
 import torch.nn as nn
+
 from model import module
 
 
tools/log_helper.py CHANGED
@@ -12,6 +12,7 @@ Example:
 """
 
 import sys
+
 from loguru import logger
 
 
train.py CHANGED
@@ -1,13 +1,16 @@
+import hydra
 from loguru import logger
+
+from config.config import Config
 from model.yolo import get_model
 from tools.log_helper import custom_logger
+from utils.dataloader import YoloDataset
 from utils.get_dataset import prepare_dataset
-import hydra
-from config.config import Config
 
 
 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
+    dataset = YoloDataset(cfg)
     if cfg.download.auto:
         prepare_dataset(cfg.download)
 
utils/data_augment.py CHANGED
@@ -1,15 +1,15 @@
-from PIL import Image
 import numpy as np
 import torch
+from PIL import Image
 from torchvision.transforms import functional as TF
-from torchvision.transforms.functional import to_tensor, to_pil_image
 
 
 class Compose:
     """Composes several transforms together."""
 
-    def __init__(self, transforms):
+    def __init__(self, transforms, image_size: int = 640):
         self.transforms = transforms
+        self.image_size = image_size
 
         for transform in self.transforms:
             if hasattr(transform, "set_parent"):
@@ -20,11 +20,8 @@ class Compose:
             image, boxes = transform(image, boxes)
         return image, boxes
 
-    def get_more_data(self):
-        raise NotImplementedError("This method should be overridden by subclass instances!")
-
 
-class RandomHorizontalFlip:
+class HorizontalFlip:
     """Randomly horizontally flips the image along with the bounding boxes."""
 
     def __init__(self, prob=0.5):
@@ -37,7 +34,7 @@ class RandomHorizontalFlip:
         return image, boxes
 
 
-class RandomVerticalFlip:
+class VerticalFlip:
     """Randomly vertically flips the image along with the bounding boxes."""
 
     def __init__(self, prob=0.5):
@@ -90,6 +87,7 @@ class Mosaic:
         all_labels.append(adjusted_boxes)
 
         all_labels = torch.cat(all_labels, dim=0)
+        mosaic_image = mosaic_image.resize((img_sz, img_sz))
        return mosaic_image, all_labels
 
 
@@ -118,10 +116,10 @@ class MixUp:
         lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
 
         # Mix images
-        image1, image2 = to_tensor(image), to_tensor(image2)
+        image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
         mixed_image = lam * image1 + (1 - lam) * image2
 
         # Mix bounding boxes
         mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
 
-        return to_pil_image(mixed_image), mixed_boxes
+        return TF.to_pil_image(mixed_image), mixed_boxes
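
Note: MixUp blends two samples with a Beta(α, α)-distributed coefficient, mixed = λ·image1 + (1−λ)·image2, and scales the two box sets by λ and 1−λ before concatenating them. A minimal usage sketch of the renamed transforms:

    # Sketch: applying the composed augmentations to one dummy sample.
    import torch
    from PIL import Image
    from utils.data_augment import Compose, HorizontalFlip

    img = Image.new("RGB", (640, 640))
    boxes = torch.tensor([[0, 0.25, 0.25, 0.75, 0.75]])  # class, xmin, ymin, xmax, ymax

    transform = Compose([HorizontalFlip(prob=1.0)], image_size=640)
    aug_img, aug_boxes = transform(img, boxes)            # boxes mirrored in x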
utils/dataloader.py CHANGED
@@ -1,26 +1,30 @@
-from PIL import Image
-from os import path, listdir
+from os import listdir, path
+from typing import List, Tuple, Union
 
+import diskcache as dc
 import hydra
 import numpy as np
 import torch
-from torch.utils.data import Dataset
 from loguru import logger
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import functional as TF
 from tqdm.rich import tqdm
-import diskcache as dc
-from typing import Union
-from drawer import draw_bboxes
-from data_augment import Compose, RandomHorizontalFlip, RandomVerticalFlip, Mosaic, MixUp
+
+from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
+from utils.drawer import draw_bboxes
 
 
 class YoloDataset(Dataset):
-    def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
+    def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
+        dataset_cfg = config.data
+        augment_cfg = config.augmentation
         phase_name = dataset_cfg.get(phase, phase)
         self.image_size = image_size
 
-        self.transform = transform
+        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
+        self.transform = Compose(transforms, self.image_size)
         self.transform.get_more_data = self.get_more_data
-        self.transform.image_size = self.image_size
         self.data = self.load_data(dataset_cfg.path, phase_name)
 
     def load_data(self, dataset_path, phase_name):
@@ -121,17 +125,55 @@ class YoloDataset(Dataset):
         img, bboxes = self.get_data(idx)
         if self.transform:
             img, bboxes = self.transform(img, bboxes)
+        img = TF.to_tensor(img)
         return img, bboxes
 
     def __len__(self) -> int:
         return len(self.data)
 
 
+class YoloDataLoader(DataLoader):
+    def __init__(self, config: dict):
+        """Initializes the YoloDataLoader with hydra-config files."""
+        hyper = config.hyper.data
+        dataset = YoloDataset(config)
+
+        super().__init__(
+            dataset,
+            batch_size=hyper.batch_size,
+            shuffle=hyper.shuffle,
+            num_workers=hyper.num_workers,
+            pin_memory=hyper.pin_memory,
+            collate_fn=self.collate_fn,
+        )
+
+    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        A collate function to handle batching of images and their corresponding targets.
+
+        Args:
+            batch (list of tuples): Each tuple contains:
+                - image (torch.Tensor): The image tensor.
+                - labels (torch.Tensor): The tensor of labels for the image.
+
+        Returns:
+            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
+                - A tensor of batched images.
+                - A list of tensors, each corresponding to bboxes for each image in the batch.
+        """
+        images = torch.stack([item[0] for item in batch])
+        targets = [item[1] for item in batch]
+        return images, targets
+
+
+def get_dataloader(config):
+    return YoloDataLoader(config)
+
+
 @hydra.main(config_path="../config", config_name="config", version_base=None)
 def main(cfg):
-    transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()])
-    dataset = YoloDataset(cfg.data, transform=transform)
-    draw_bboxes(*dataset[0])
+    dataloader = get_dataloader(cfg)
+    draw_bboxes(next(iter(dataloader)))
 
 
 if __name__ == "__main__":
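
Note: the custom collate_fn exists because all images share one size and can be stacked into a single tensor, while the number of boxes varies per image, so targets stay a Python list rather than a padded tensor. A sketch of the resulting batch structure, assuming the default hyper config above (batch_size 4, 640×640 inputs):

    # Sketch: shape of one batch from YoloDataLoader.
    images, targets = next(iter(dataloader))
    print(images.shape)      # torch.Size([4, 3, 640, 640]) - stacked image tensor
    print(len(targets))      # 4 - one bbox tensor per image
    print(targets[0].shape)  # (n_boxes, 5): class, xmin, ymin, xmax, ymax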
utils/drawer.py CHANGED
@@ -1,23 +1,31 @@
+from typing import List, Union
+
+import torch
+from loguru import logger
 from PIL import Image, ImageDraw, ImageFont
+from torchvision.transforms.functional import to_pil_image
 
 
-def draw_bboxes(img, bboxes):
+def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
     """
     Draw bounding boxes on an image.
 
     Args:
-    - image_path (str): Path to the image file.
-    - bboxes (list of lists/tuples): Bounding boxes with [x_min, y_min, x_max, y_max, class_id].
+    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
+    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
+      where coordinates are normalized [0, 1].
     """
-    # Load an image
-    draw = ImageDraw.Draw(img)
+    # Convert tensor image to PIL Image if necessary
+    if isinstance(img, torch.Tensor):
+        if img.dim() > 3:
+            logger.info("Multi-frame tensor detected, using the first image.")
+            img = img[0]
+            bboxes = bboxes[0]
+        img = to_pil_image(img)
 
-    # Font for class_id (optional)
-    try:
-        font = ImageFont.truetype("arial.ttf", 30)
-    except IOError:
-        font = ImageFont.load_default(30)
+    draw = ImageDraw.Draw(img)
     width, height = img.size
+    font = ImageFont.load_default(30)
 
     for bbox in bboxes:
         class_id, x_min, y_min, x_max, y_max = bbox
@@ -26,7 +34,8 @@ def draw_bboxes(img, bboxes):
         y_min = y_min * height
         y_max = y_max * height
         shape = [(x_min, y_min), (x_max, y_max)]
-        draw.rectangle(shape, outline="red", width=2)
+        draw.rectangle(shape, outline="red", width=3)
         draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
 
-        img.save("output.jpg")
+    img.save("visualize.jpg")  # Save the image with annotations
+    logger.info("Saved visualize image at visualize.png")
utils/get_dataset.py CHANGED
@@ -2,8 +2,8 @@ import os
 import zipfile
 
 import hydra
-from loguru import logger
 import requests
+from loguru import logger
 from tqdm.rich import tqdm
 
 