Merge branch 'DATASET' of https://github.com/WongKinYiu/yolov9mit into DATASET
- .github/workflows/main.yaml +26 -0
- .pre-commit-config.yaml +6 -0
- README.md +15 -1
- config/config.py +1 -1
- config/config.yaml +1 -0
- config/data/augmentation.yaml +2 -2
- config/hyper/default.yaml +5 -0
- config/model/v7-base.yaml +1 -0
- model/module.py +30 -4
- model/yolo.py +9 -3
- requirements.txt +2 -0
- tests/test_model/test_yolo.py +48 -0
- tests/test_utils/test_dataaugment.py +64 -0
- tools/layer_helper.py +2 -0
- tools/log_helper.py +1 -0
- train.py +5 -2
- utils/data_augment.py +8 -10
- utils/dataloader.py +55 -13
- utils/drawer.py +21 -12
- utils/get_dataset.py +1 -1
.github/workflows/main.yaml
ADDED
@@ -0,0 +1,26 @@
+name: YOLOv9 - Model test
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -r requirements.txt
+    - name: Test with pytest
+      run: |
+        pytest
.pre-commit-config.yaml
CHANGED
@@ -6,3 +6,9 @@ repos:
         language_version: python3 # Specify the Python version
         exclude: '.*\.yaml$' # Regex pattern to exclude all YAML files
         args: ["--line-length", "120"] # Set max line length to 120 characters
+
+  - repo: https://github.com/pre-commit/mirrors-isort
+    rev: v5.10.1 # Use the appropriate version or "stable" for the latest stable release
+    hooks:
+      - id: isort
+        args: ["--profile", "black", "--verbose"]
README.md
CHANGED
@@ -1,6 +1,20 @@
 # YOLOv9-MIT
 An MIT license rewrite of YOLOv9
 
+
+> [!IMPORTANT]
+> This project is currently a Work In Progress and may undergo significant changes. It is not recommended for use in production environments until further notice. Please check back regularly for updates.
+>
+> Use of this code is at your own risk and discretion. It is advisable to consult with the project owner before deploying or integrating into any critical systems.
+
+## Contributing
+
+While the project's structure is still being finalized, we ask that potential contributors wait for these foundational decisions to be made. We greatly appreciate your patience and are excited to welcome contributions from the community once we are ready. Alternatively, you are welcome to propose functions that should be implemented based on the original YOLO version or suggest other enhancements!
+
+If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
+
+
+
 ## To-Do Lists
 - [ ] Project Setup
   - [X] requirements
@@ -16,7 +30,7 @@ An MIT license rewrite of YOLOv9
   - [ ] Auto Download
   - [ ] xywh, xxyy, xcyc
 - [ ] Dataloader
-  - [ ] Data
+  - [ ] Data augment
 - [ ] Model
   - [ ] load model
     - [ ] from yaml
config/config.py
CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Union
 
 
 @dataclass
config/config.yaml
CHANGED
@@ -7,4 +7,5 @@ defaults:
   - download: ../data/download
   - augmentation: ../data/augmentation
   - model: v7-base
+  - hyper: default
   - _self_
config/data/augmentation.yaml
CHANGED
@@ -1,3 +1,3 @@
 Mosaic: 1
-MixUp: 1
-
+# MixUp: 1
+HorizontalFlip: 0.5
config/hyper/default.yaml
ADDED
@@ -0,0 +1,5 @@
+data:
+  batch_size: 4
+  shuffle: True
+  num_workers: 4
+  pin_memory: True
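Since config/config.yaml now lists `- hyper: default` under defaults, these values are composed into the runtime config under cfg.hyper.data. A minimal sketch of reading them back after composition, assuming the script sits next to train.py at the repo root (the prints are illustrative, not part of the commit):

import hydra
from omegaconf import OmegaConf


@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg):
    # After composition, the hyper/default.yaml values are available here:
    print(OmegaConf.to_yaml(cfg.hyper.data))  # batch_size: 4, shuffle: true, ...
    print(cfg.hyper.data.batch_size)          # 4


if __name__ == "__main__":
    main()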
config/model/v7-base.yaml
CHANGED
@@ -241,3 +241,4 @@ model:
       - [36,75, 76,55, 72,146]  # P4/16
       - [142,110, 192,243, 459,401]  # P5/32
     source: [102, 103, 104]
+    output: True
model/module.py
CHANGED
@@ -11,10 +11,10 @@ class Conv(nn.Module):
         out_channels,
         kernel_size,
         stride=1,
-        padding=
+        padding=None,
         dilation=1,
         groups=1,
-        act=nn.
+        act=nn.SiLU(),
         bias=False,
         auto_padding=True,
         padding_mode="zeros",
@@ -48,10 +48,12 @@
 # RepVGG
 class RepConv(nn.Module):
     # https://github.com/DingXiaoH/RepVGG
-    def __init__(
+    def __init__(
+        self, in_channels, out_channels, kernel_size=3, padding=None, stride=1, groups=1, act=nn.SiLU(), deploy=False
+    ):
 
         super().__init__()
-
+        self.deploy = deploy
         self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, groups=groups, act=False)
         self.conv2 = Conv(in_channels, out_channels, 1, stride, groups=groups, act=False)
         self.act = act if isinstance(act, nn.Module) else nn.Identity()
@@ -64,6 +66,30 @@
 
     # to be implement
     # def fuse_convs(self):
+    def fuse_conv_bn(self, conv, bn):
+
+        std = (bn.running_var + bn.eps).sqrt()
+        bias = bn.bias - bn.running_mean * bn.weight / std
+
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        weights = conv.weight * t
+
+        bn = nn.Identity()
+        conv = nn.Conv2d(
+            in_channels=conv.in_channels,
+            out_channels=conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            dilation=conv.dilation,
+            groups=conv.groups,
+            bias=True,
+            padding_mode=conv.padding_mode,
+        )
+
+        conv.weight = torch.nn.Parameter(weights)
+        conv.bias = torch.nn.Parameter(bias)
+        return conv
 
 
 # ResNet
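fuse_conv_bn folds a trailing BatchNorm into the convolution itself: with σ = sqrt(running_var + eps), the fused weights become w·γ/σ and the fused bias β - μ·γ/σ, the standard RepVGG-style deployment trick. A standalone numerical check of that identity (layer sizes are arbitrary):

import torch
import torch.nn as nn

# Build a conv followed by a batch norm in eval mode (fixed running stats).
conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(16).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)

# Fold BN into the conv, mirroring the fuse_conv_bn math above.
std = (bn.running_var + bn.eps).sqrt()
fused = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=True)
fused.weight = nn.Parameter(conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1))
fused.bias = nn.Parameter(bn.bias - bn.running_mean * bn.weight / std)

x = torch.randn(2, 8, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-4)  # identical outputs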
model/yolo.py
CHANGED
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 from loguru import logger
 from omegaconf import OmegaConf
+
 from tools.layer_helper import get_layer_map
 
 
@@ -32,6 +33,7 @@ class YOLO(nn.Module):
             layer_type, layer_info = next(iter(layer_spec.items()))
             layer_args = layer_info.get("args", {})
             source = layer_info.get("source", -1)
+            output = layer_info.get("output", False)
 
             if isinstance(source, str):
                 source = layer_indices_by_tag[source]
@@ -41,7 +43,7 @@ class YOLO(nn.Module):
                 layer_args["nc"] = self.nc
                 layer_args["ch"] = [output_dim[idx] for idx in source]
 
-            layer = self.create_layer(layer_type, source, **layer_args)
+            layer = self.create_layer(layer_type, source, output, **layer_args)
             model_list.append(layer)
 
             if "tags" in layer_info:
@@ -55,6 +57,7 @@ class YOLO(nn.Module):
 
     def forward(self, x):
        y = [x]
+        output = []
         for layer in self.model:
             if OmegaConf.is_list(layer.source):
                 model_input = [y[idx] for idx in layer.source]
@@ -62,7 +65,9 @@ class YOLO(nn.Module):
                 model_input = y[layer.source]
             x = layer(model_input)
             y.append(x)
-
+            if layer.output:
+                output.append(x)
+        return output
 
     def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
         if "Conv" in layer_type:
@@ -74,10 +79,11 @@ class YOLO(nn.Module):
         if layer_type == "IDetect":
             return None
 
-    def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
+    def create_layer(self, layer_type: str, source: Union[int, list], output=False, **kwargs):
         if layer_type in self.layer_map:
             layer = self.layer_map[layer_type](**kwargs)
             layer.source = source
+            layer.output = output
             return layer
         else:
             raise ValueError(f"Unsupported layer type: {layer_type}")
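create_layer now stamps each module with source and output attributes, and forward threads activations through y while collecting every tagged x into the returned list. A stripped-down sketch of the same routing, assuming plain module attributes in place of the config and OmegaConf handling:

import torch
import torch.nn as nn

# Tag some layers of a toy model and collect only their activations,
# mimicking how YOLO.forward returns the feature maps marked output: True.
layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(4)])
for i, layer in enumerate(layers):
    layer.source = -1              # consume the previous activation
    layer.output = i in (1, 3)     # only layers 1 and 3 are model outputs

def forward(x):
    y, output = [x], []
    for layer in layers:
        x = layer(y[layer.source])
        y.append(x)
        if layer.output:
            output.append(x)
    return output

outs = forward(torch.randn(2, 4))
assert len(outs) == 2  # one tensor per tagged layer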
requirements.txt
CHANGED
@@ -1,11 +1,13 @@
 hydra-core
 loguru
 numpy
+Pillow
 pytest
 pyyaml
 requests
 rich
 torch
+torchvision
 tqdm
 Pillow
 diskcache
tests/test_model/test_yolo.py
ADDED
@@ -0,0 +1,48 @@
+import sys
+
+import pytest
+import torch
+from hydra import compose, initialize
+from hydra.core.global_hydra import GlobalHydra
+from omegaconf import DictConfig, OmegaConf
+
+sys.path.append("./")
+from model.yolo import YOLO, get_model
+
+config_path = "../../config/model"
+config_name = "v7-base"
+
+
+def test_build_model():
+
+    with initialize(config_path=config_path, version_base=None):
+        model_cfg = compose(config_name=config_name)
+        OmegaConf.set_struct(model_cfg, False)
+        model = YOLO(model_cfg)
+        model.build_model(model_cfg.model)
+        assert len(model.model) == 106
+
+
+def test_get_model():
+    with initialize(config_path=config_path, version_base=None):
+        model_cfg = compose(config_name=config_name)
+        model = get_model(model_cfg)
+        assert isinstance(model, YOLO)
+
+
+def test_yolo_forward_output_shape():
+    with initialize(config_path=config_path, version_base=None):
+        model_cfg = compose(config_name=config_name)
+
+        model = get_model(model_cfg)
+        # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
+        dummy_input = torch.rand(2, 3, 640, 640)
+
+        # Forward pass through the model
+        output = model(dummy_input)
+        output_shape = [x.shape for x in output[-1]]
+        assert output_shape == [
+            torch.Size([2, 3, 20, 20, 85]),
+            torch.Size([2, 3, 80, 80, 85]),
+            torch.Size([2, 3, 40, 40, 85]),
+        ]
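The expected shapes fall out of the three detection strides and the COCO head layout: 640/32 = 20, 640/8 = 80, and 640/16 = 40 grid cells per side (in the order the test asserts), 3 anchors per cell, and 85 channels = 4 box coordinates + 1 objectness + 80 classes. The same arithmetic as a sketch:

# Per-head shape arithmetic behind the assertion above (COCO: 80 classes).
image_size, anchors_per_cell, num_classes = 640, 3, 80
channels = 4 + 1 + num_classes  # box (4) + objectness (1) + classes = 85

for stride in (32, 8, 16):  # order matches the asserted output list
    grid = image_size // stride
    print((2, anchors_per_cell, grid, grid, channels))
# (2, 3, 20, 20, 85), (2, 3, 80, 80, 85), (2, 3, 40, 40, 85)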
tests/test_utils/test_dataaugment.py
ADDED
@@ -0,0 +1,64 @@
+import sys
+
+import pytest
+import torch
+from PIL import Image
+from torchvision.transforms import functional as TF
+
+sys.path.append("./")
+from utils.data_augment import Compose, Mosaic, RandomHorizontalFlip
+
+
+def test_random_horizontal_flip():
+    # Create a mock image and bounding boxes
+    img = Image.new("RGB", (100, 100), color="red")
+    boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])  # class, xmin, ymin, xmax, ymax
+
+    flip_transform = RandomHorizontalFlip(prob=1)  # Set probability to 1 to ensure flip
+    flipped_img, flipped_boxes = flip_transform(img, boxes)
+
+    # Assert image is flipped by comparing it to a manually flipped image
+    assert TF.hflip(img) == flipped_img
+
+    # Assert bounding boxes are flipped correctly
+    expected_boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])
+    expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
+    assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
+
+
+def test_compose():
+    # Define two mock transforms that simply return the inputs
+    def mock_transform(image, boxes):
+        return image, boxes
+
+    compose = Compose([mock_transform, mock_transform])
+    img = Image.new("RGB", (10, 10), color="blue")
+    boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
+
+    transformed_img, transformed_boxes = compose(img, boxes)
+
+    assert transformed_img == img, "Image should not be altered"
+    assert torch.equal(transformed_boxes, boxes), "Boxes should not be altered"
+
+
+def test_mosaic():
+    img = Image.new("RGB", (100, 100), color="green")
+    boxes = torch.tensor([[0, 0.25, 0.25, 0.75, 0.75]])
+
+    # Mock parent with image_size and get_more_data method
+    class MockParent:
+        image_size = 100
+
+        def get_more_data(self, num_images):
+            return [(img, boxes) for _ in range(num_images)]
+
+    mosaic = Mosaic(prob=1)  # Ensure mosaic is applied
+    mosaic.set_parent(MockParent())
+
+    mosaic_img, mosaic_boxes = mosaic(img, boxes)
+
+    # Checks here would depend on the exact expected behavior of the mosaic function,
+    # such as dimensions and content of the output image and boxes.
+
+    assert mosaic_img.size == (200, 200), "Mosaic image size should be doubled"
+    assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
tools/layer_helper.py
CHANGED
@@ -1,5 +1,7 @@
 import inspect
+
 import torch.nn as nn
+
 from model import module
 
 
tools/log_helper.py
CHANGED
@@ -12,6 +12,7 @@ Example:
 """
 
 import sys
+
 from loguru import logger
 
 
train.py
CHANGED
@@ -1,13 +1,16 @@
+import hydra
 from loguru import logger
+
+from config.config import Config
 from model.yolo import get_model
 from tools.log_helper import custom_logger
+from utils.dataloader import YoloDataset
 from utils.get_dataset import prepare_dataset
-import hydra
-from config.config import Config
 
 
 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
+    dataset = YoloDataset(cfg)
     if cfg.download.auto:
         prepare_dataset(cfg.download)
 
utils/data_augment.py
CHANGED
@@ -1,15 +1,15 @@
-from PIL import Image
 import numpy as np
 import torch
+from PIL import Image
 from torchvision.transforms import functional as TF
-from torchvision.transforms.functional import to_tensor, to_pil_image
 
 
 class Compose:
     """Composes several transforms together."""
 
-    def __init__(self, transforms):
+    def __init__(self, transforms, image_size: int = 640):
         self.transforms = transforms
+        self.image_size = image_size
 
         for transform in self.transforms:
             if hasattr(transform, "set_parent"):
@@ -20,11 +20,8 @@ class Compose:
             image, boxes = transform(image, boxes)
         return image, boxes
 
-    def get_more_data(self):
-        raise NotImplementedError("This method should be overridden by subclass instances!")
-
 
-class RandomHorizontalFlip:
+class HorizontalFlip:
     """Randomly horizontally flips the image along with the bounding boxes."""
 
     def __init__(self, prob=0.5):
@@ -37,7 +34,7 @@ class RandomHorizontalFlip:
         return image, boxes
 
 
-class RandomVerticalFlip:
+class VerticalFlip:
     """Randomly vertically flips the image along with the bounding boxes."""
 
     def __init__(self, prob=0.5):
@@ -90,6 +87,7 @@ class Mosaic:
             all_labels.append(adjusted_boxes)
 
         all_labels = torch.cat(all_labels, dim=0)
+        mosaic_image = mosaic_image.resize((img_sz, img_sz))
        return mosaic_image, all_labels
 
 
@@ -118,10 +116,10 @@ class MixUp:
         lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
 
         # Mix images
-        image1, image2 = to_tensor(image), to_tensor(image2)
+        image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
         mixed_image = lam * image1 + (1 - lam) * image2
 
         # Mix bounding boxes
         mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
 
-        return to_pil_image(mixed_image), mixed_boxes
+        return TF.to_pil_image(mixed_image), mixed_boxes
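MixUp blends two samples with a weight lam drawn from Beta(alpha, alpha): images are mixed pixel-wise as lam * img1 + (1 - lam) * img2, and the two box sets are concatenated with the same weights. A minimal standalone illustration of the mixing step (image sizes and box values are made up):

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import functional as TF

alpha = 1.0
lam = np.random.beta(alpha, alpha)  # mixing weight in (0, 1)

img1 = TF.to_tensor(Image.new("RGB", (64, 64), color="red"))
img2 = TF.to_tensor(Image.new("RGB", (64, 64), color="blue"))

mixed = lam * img1 + (1 - lam) * img2  # pixel-wise blend
boxes1 = torch.tensor([[0, 0.1, 0.1, 0.5, 0.5]])
boxes2 = torch.tensor([[1, 0.4, 0.4, 0.9, 0.9]])
mixed_boxes = torch.cat([lam * boxes1, (1 - lam) * boxes2])  # as in MixUp above

print(mixed.shape, mixed_boxes.shape)  # torch.Size([3, 64, 64]) torch.Size([2, 5])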
utils/dataloader.py
CHANGED
@@ -1,26 +1,30 @@
-from
-from
+from os import listdir, path
+from typing import List, Tuple, Union
 
+import diskcache as dc
 import hydra
 import numpy as np
 import torch
-from torch.utils.data import Dataset
 from loguru import logger
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import functional as TF
 from tqdm.rich import tqdm
-
-from
-from drawer import draw_bboxes
-from data_augment import Compose, RandomHorizontalFlip, RandomVerticalFlip, Mosaic, MixUp
+
+from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
+from utils.drawer import draw_bboxes
 
 
 class YoloDataset(Dataset):
-    def __init__(self,
+    def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
+        dataset_cfg = config.data
+        augment_cfg = config.augmentation
         phase_name = dataset_cfg.get(phase, phase)
         self.image_size = image_size
 
-
+        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
+        self.transform = Compose(transforms, self.image_size)
         self.transform.get_more_data = self.get_more_data
-        self.transform.image_size = self.image_size
         self.data = self.load_data(dataset_cfg.path, phase_name)
 
     def load_data(self, dataset_path, phase_name):
@@ -121,17 +125,55 @@ class YoloDataset(Dataset):
         img, bboxes = self.get_data(idx)
         if self.transform:
             img, bboxes = self.transform(img, bboxes)
+        img = TF.to_tensor(img)
         return img, bboxes
 
     def __len__(self) -> int:
         return len(self.data)
 
 
+class YoloDataLoader(DataLoader):
+    def __init__(self, config: dict):
+        """Initializes the YoloDataLoader with hydra-config files."""
+        hyper = config.hyper.data
+        dataset = YoloDataset(config)
+
+        super().__init__(
+            dataset,
+            batch_size=hyper.batch_size,
+            shuffle=hyper.shuffle,
+            num_workers=hyper.num_workers,
+            pin_memory=hyper.pin_memory,
+            collate_fn=self.collate_fn,
+        )
+
+    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        A collate function to handle batching of images and their corresponding targets.
+
+        Args:
+            batch (list of tuples): Each tuple contains:
+                - image (torch.Tensor): The image tensor.
+                - labels (torch.Tensor): The tensor of labels for the image.
+
+        Returns:
+            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
+                - A tensor of batched images.
+                - A list of tensors, each corresponding to bboxes for each image in the batch.
+        """
+        images = torch.stack([item[0] for item in batch])
+        targets = [item[1] for item in batch]
+        return images, targets
+
+
+def get_dataloader(config):
+    return YoloDataLoader(config)
+
+
 @hydra.main(config_path="../config", config_name="config", version_base=None)
 def main(cfg):
-
-
-    draw_bboxes(*dataset[0])
+    dataloader = get_dataloader(cfg)
+    draw_bboxes(next(iter(dataloader)))
 
 
 if __name__ == "__main__":
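The custom collate_fn exists because images in a detection batch share a fixed size while their box counts differ, so the images can be stacked into one tensor but the targets must stay a per-image list. A small standalone demonstration of that behavior (shapes are illustrative):

import torch

def collate_fn(batch):
    # Stack fixed-size images; keep variable-length box tensors as a list.
    images = torch.stack([item[0] for item in batch])
    targets = [item[1] for item in batch]
    return images, targets

batch = [
    (torch.rand(3, 640, 640), torch.rand(2, 5)),  # image with 2 boxes
    (torch.rand(3, 640, 640), torch.rand(7, 5)),  # image with 7 boxes
]
images, targets = collate_fn(batch)
print(images.shape)                   # torch.Size([2, 3, 640, 640])
print([t.shape[0] for t in targets])  # [2, 7]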
utils/drawer.py
CHANGED
@@ -1,23 +1,31 @@
+from typing import List, Union
+
+import torch
+from loguru import logger
 from PIL import Image, ImageDraw, ImageFont
+from torchvision.transforms.functional import to_pil_image
 
 
-def draw_bboxes(img, bboxes):
+def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
     """
     Draw bounding boxes on an image.
 
     Args:
-    -
-    - bboxes (
+    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
+    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
+      where coordinates are normalized [0, 1].
     """
-    #
-
+    # Convert tensor image to PIL Image if necessary
+    if isinstance(img, torch.Tensor):
+        if img.dim() > 3:
+            logger.info("Multi-frame tensor detected, using the first image.")
+            img = img[0]
+            bboxes = bboxes[0]
+        img = to_pil_image(img)
 
-
-    try:
-        font = ImageFont.truetype("arial.ttf", 30)
-    except IOError:
-        font = ImageFont.load_default(30)
+    draw = ImageDraw.Draw(img)
     width, height = img.size
+    font = ImageFont.load_default(30)
 
     for bbox in bboxes:
         class_id, x_min, y_min, x_max, y_max = bbox
@@ -26,7 +34,8 @@ def draw_bboxes(img, bboxes):
         y_min = y_min * height
         y_max = y_max * height
         shape = [(x_min, y_min), (x_max, y_max)]
-        draw.rectangle(shape, outline="red", width=
+        draw.rectangle(shape, outline="red", width=3)
         draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
 
-        img.save("
+    img.save("visualize.jpg")  # Save the image with annotations
+    logger.info("Saved visualize image at visualize.png")
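With the tensor branch added, draw_bboxes accepts either a PIL image or a batched tensor straight from the dataloader, falling back to the first image of a batch. A hypothetical usage sketch, assuming it is run from the repo root (all box values are made up):

import torch
from PIL import Image
from utils.drawer import draw_bboxes

# PIL input: normalized [class_id, x_min, y_min, x_max, y_max] boxes.
img = Image.new("RGB", (640, 640), color="gray")
draw_bboxes(img, [[0, 0.1, 0.1, 0.5, 0.5], [2, 0.6, 0.6, 0.9, 0.9]])

# Tensor input: a (B, C, H, W) batch is reduced to its first image.
batch = torch.rand(4, 3, 640, 640)
boxes = [torch.tensor([[0, 0.2, 0.2, 0.7, 0.7]]) for _ in range(4)]
draw_bboxes(batch, boxes)  # writes visualize.jpg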
utils/get_dataset.py
CHANGED
@@ -2,8 +2,8 @@ import os
 import zipfile
 
 import hydra
-from loguru import logger
 import requests
+from loguru import logger
 from tqdm.rich import tqdm
 
 