🔀 [Merge] branch 'SETUP' into DATASET
Browse files- .github/workflows/main.yaml +26 -0
- .pre-commit-config.yaml +6 -0
- README.md +14 -0
- config/config.py +1 -1
- config/model/v7-base.yaml +1 -0
- model/module.py +30 -4
- model/yolo.py +9 -3
- requirements.txt +2 -0
- tests/test_model/test_yolo.py +48 -0
- tests/test_utils/test_dataaugment.py +64 -0
- tools/layer_helper.py +2 -0
- tools/log_helper.py +1 -0
- train.py +5 -2
- utils/data_augment.py +4 -6
- utils/dataloader.py +20 -12
- utils/get_dataset.py +1 -1
.github/workflows/main.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: YOLOv9 - Model test
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [ main ]
|
6 |
+
pull_request:
|
7 |
+
branches: [ main ]
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
build:
|
11 |
+
|
12 |
+
runs-on: ubuntu-latest
|
13 |
+
|
14 |
+
steps:
|
15 |
+
- uses: actions/checkout@v2
|
16 |
+
- name: Set up Python 3.8
|
17 |
+
uses: actions/setup-python@v2
|
18 |
+
with:
|
19 |
+
python-version: 3.8
|
20 |
+
- name: Install dependencies
|
21 |
+
run: |
|
22 |
+
python -m pip install --upgrade pip
|
23 |
+
pip install -r requirements.txt
|
24 |
+
- name: Test with pytest
|
25 |
+
run: |
|
26 |
+
pytest
|
.pre-commit-config.yaml
CHANGED
@@ -6,3 +6,9 @@ repos:
|
|
6 |
language_version: python3 # Specify the Python version
|
7 |
exclude: '.*\.yaml$' # Regex pattern to exclude all YAML files
|
8 |
args: ["--line-length", "120"] # Set max line length to 120 characters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
language_version: python3 # Specify the Python version
|
7 |
exclude: '.*\.yaml$' # Regex pattern to exclude all YAML files
|
8 |
args: ["--line-length", "120"] # Set max line length to 120 characters
|
9 |
+
|
10 |
+
- repo: https://github.com/pre-commit/mirrors-isort
|
11 |
+
rev: v5.10.1 # Use the appropriate version or "stable" for the latest stable release
|
12 |
+
hooks:
|
13 |
+
- id: isort
|
14 |
+
args: ["--profile", "black", "--verbose"]
|
README.md
CHANGED
@@ -1,6 +1,20 @@
|
|
1 |
# YOLOv9-MIT
|
2 |
An MIT license rewrite of YOLOv9
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
## To-Do Lists
|
5 |
- [ ] Project Setup
|
6 |
- [X] requirements
|
|
|
1 |
# YOLOv9-MIT
|
2 |
An MIT license rewrite of YOLOv9
|
3 |
|
4 |
+

|
5 |
+
> [!IMPORTANT]
|
6 |
+
> This project is currently a Work In Progress and may undergo significant changes. It is not recommended for use in production environments until further notice. Please check back regularly for updates.
|
7 |
+
>
|
8 |
+
> Use of this code is at your own risk and discretion. It is advisable to consult with the project owner before deploying or integrating into any critical systems.
|
9 |
+
|
10 |
+
## Contributing
|
11 |
+
|
12 |
+
While the project's structure is still being finalized, we ask that potential contributors wait for these foundational decisions to be made. We greatly appreciate your patience and are excited to welcome contributions from the community once we are ready. Alternatively, you are welcome to propose functions that should be implemented based on the original YOLO version or suggest other enhancements!
|
13 |
+
|
14 |
+
If you are interested in contributing, please keep an eye on project updates or contact us directly at [[email protected]](mailto:[email protected]) for more information.
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
## To-Do Lists
|
19 |
- [ ] Project Setup
|
20 |
- [X] requirements
|
config/config.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
-
from typing import
|
3 |
|
4 |
|
5 |
@dataclass
|
|
|
1 |
from dataclasses import dataclass
|
2 |
+
from typing import Dict, List, Union
|
3 |
|
4 |
|
5 |
@dataclass
|
config/model/v7-base.yaml
CHANGED
@@ -241,3 +241,4 @@ model:
|
|
241 |
- [36,75, 76,55, 72,146] # P4/16
|
242 |
- [142,110, 192,243, 459,401] # P5/32
|
243 |
source: [102, 103, 104]
|
|
|
|
241 |
- [36,75, 76,55, 72,146] # P4/16
|
242 |
- [142,110, 192,243, 459,401] # P5/32
|
243 |
source: [102, 103, 104]
|
244 |
+
output: True
|
model/module.py
CHANGED
@@ -11,10 +11,10 @@ class Conv(nn.Module):
|
|
11 |
out_channels,
|
12 |
kernel_size,
|
13 |
stride=1,
|
14 |
-
padding=
|
15 |
dilation=1,
|
16 |
groups=1,
|
17 |
-
act=nn.
|
18 |
bias=False,
|
19 |
auto_padding=True,
|
20 |
padding_mode="zeros",
|
@@ -48,10 +48,12 @@ class Conv(nn.Module):
|
|
48 |
# RepVGG
|
49 |
class RepConv(nn.Module):
|
50 |
# https://github.com/DingXiaoH/RepVGG
|
51 |
-
def __init__(
|
|
|
|
|
52 |
|
53 |
super().__init__()
|
54 |
-
|
55 |
self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, groups=groups, act=False)
|
56 |
self.conv2 = Conv(in_channels, out_channels, 1, stride, groups=groups, act=False)
|
57 |
self.act = act if isinstance(act, nn.Module) else nn.Identity()
|
@@ -64,6 +66,30 @@ class RepConv(nn.Module):
|
|
64 |
|
65 |
# to be implemented
|
66 |
# def fuse_convs(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
# ResNet
|
|
|
11 |
out_channels,
|
12 |
kernel_size,
|
13 |
stride=1,
|
14 |
+
padding=None,
|
15 |
dilation=1,
|
16 |
groups=1,
|
17 |
+
act=nn.SiLU(),
|
18 |
bias=False,
|
19 |
auto_padding=True,
|
20 |
padding_mode="zeros",
|
|
|
48 |
# RepVGG
|
49 |
class RepConv(nn.Module):
|
50 |
# https://github.com/DingXiaoH/RepVGG
|
51 |
+
def __init__(
|
52 |
+
self, in_channels, out_channels, kernel_size=3, padding=None, stride=1, groups=1, act=nn.SiLU(), deploy=False
|
53 |
+
):
|
54 |
|
55 |
super().__init__()
|
56 |
+
self.deploy = deploy
|
57 |
self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, groups=groups, act=False)
|
58 |
self.conv2 = Conv(in_channels, out_channels, 1, stride, groups=groups, act=False)
|
59 |
self.act = act if isinstance(act, nn.Module) else nn.Identity()
|
|
|
66 |
|
67 |
# to be implemented
|
68 |
# def fuse_convs(self):
|
69 |
+
def fuse_conv_bn(self, conv, bn):
|
70 |
+
|
71 |
+
std = (bn.running_var + bn.eps).sqrt()
|
72 |
+
bias = bn.bias - bn.running_mean * bn.weight / std
|
73 |
+
|
74 |
+
t = (bn.weight / std).reshape(-1, 1, 1, 1)
|
75 |
+
weights = conv.weight * t
|
76 |
+
|
77 |
+
bn = nn.Identity()
|
78 |
+
conv = nn.Conv2d(
|
79 |
+
in_channels=conv.in_channels,
|
80 |
+
out_channels=conv.out_channels,
|
81 |
+
kernel_size=conv.kernel_size,
|
82 |
+
stride=conv.stride,
|
83 |
+
padding=conv.padding,
|
84 |
+
dilation=conv.dilation,
|
85 |
+
groups=conv.groups,
|
86 |
+
bias=True,
|
87 |
+
padding_mode=conv.padding_mode,
|
88 |
+
)
|
89 |
+
|
90 |
+
conv.weight = torch.nn.Parameter(weights)
|
91 |
+
conv.bias = torch.nn.Parameter(bias)
|
92 |
+
return conv
|
93 |
|
94 |
|
95 |
# ResNet
|
model/yolo.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
from loguru import logger
|
6 |
from omegaconf import OmegaConf
|
|
|
7 |
from tools.layer_helper import get_layer_map
|
8 |
|
9 |
|
@@ -32,6 +33,7 @@ class YOLO(nn.Module):
|
|
32 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
33 |
layer_args = layer_info.get("args", {})
|
34 |
source = layer_info.get("source", -1)
|
|
|
35 |
|
36 |
if isinstance(source, str):
|
37 |
source = layer_indices_by_tag[source]
|
@@ -41,7 +43,7 @@ class YOLO(nn.Module):
|
|
41 |
layer_args["nc"] = self.nc
|
42 |
layer_args["ch"] = [output_dim[idx] for idx in source]
|
43 |
|
44 |
-
layer = self.create_layer(layer_type, source, **layer_args)
|
45 |
model_list.append(layer)
|
46 |
|
47 |
if "tags" in layer_info:
|
@@ -55,6 +57,7 @@ class YOLO(nn.Module):
|
|
55 |
|
56 |
def forward(self, x):
|
57 |
y = [x]
|
|
|
58 |
for layer in self.model:
|
59 |
if OmegaConf.is_list(layer.source):
|
60 |
model_input = [y[idx] for idx in layer.source]
|
@@ -62,7 +65,9 @@ class YOLO(nn.Module):
|
|
62 |
model_input = y[layer.source]
|
63 |
x = layer(model_input)
|
64 |
y.append(x)
|
65 |
-
|
|
|
|
|
66 |
|
67 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
68 |
if "Conv" in layer_type:
|
@@ -74,10 +79,11 @@ class YOLO(nn.Module):
|
|
74 |
if layer_type == "IDetect":
|
75 |
return None
|
76 |
|
77 |
-
def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
|
78 |
if layer_type in self.layer_map:
|
79 |
layer = self.layer_map[layer_type](**kwargs)
|
80 |
layer.source = source
|
|
|
81 |
return layer
|
82 |
else:
|
83 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
|
|
4 |
import torch.nn as nn
|
5 |
from loguru import logger
|
6 |
from omegaconf import OmegaConf
|
7 |
+
|
8 |
from tools.layer_helper import get_layer_map
|
9 |
|
10 |
|
|
|
33 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
34 |
layer_args = layer_info.get("args", {})
|
35 |
source = layer_info.get("source", -1)
|
36 |
+
output = layer_info.get("output", False)
|
37 |
|
38 |
if isinstance(source, str):
|
39 |
source = layer_indices_by_tag[source]
|
|
|
43 |
layer_args["nc"] = self.nc
|
44 |
layer_args["ch"] = [output_dim[idx] for idx in source]
|
45 |
|
46 |
+
layer = self.create_layer(layer_type, source, output, **layer_args)
|
47 |
model_list.append(layer)
|
48 |
|
49 |
if "tags" in layer_info:
|
|
|
57 |
|
58 |
def forward(self, x):
|
59 |
y = [x]
|
60 |
+
output = []
|
61 |
for layer in self.model:
|
62 |
if OmegaConf.is_list(layer.source):
|
63 |
model_input = [y[idx] for idx in layer.source]
|
|
|
65 |
model_input = y[layer.source]
|
66 |
x = layer(model_input)
|
67 |
y.append(x)
|
68 |
+
if layer.output:
|
69 |
+
output.append(x)
|
70 |
+
return output
|
71 |
|
72 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
73 |
if "Conv" in layer_type:
|
|
|
79 |
if layer_type == "IDetect":
|
80 |
return None
|
81 |
|
82 |
+
def create_layer(self, layer_type: str, source: Union[int, list], output=False, **kwargs):
|
83 |
if layer_type in self.layer_map:
|
84 |
layer = self.layer_map[layer_type](**kwargs)
|
85 |
layer.source = source
|
86 |
+
layer.output = output
|
87 |
return layer
|
88 |
else:
|
89 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
requirements.txt
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
hydra-core
|
2 |
loguru
|
3 |
numpy
|
|
|
4 |
pytest
|
5 |
pyyaml
|
6 |
requests
|
7 |
rich
|
8 |
torch
|
|
|
9 |
tqdm
|
10 |
Pillow
|
11 |
diskcache
|
|
|
1 |
hydra-core
|
2 |
loguru
|
3 |
numpy
|
4 |
+
Pillow
|
5 |
pytest
|
6 |
pyyaml
|
7 |
requests
|
8 |
rich
|
9 |
torch
|
10 |
+
torchvision
|
11 |
tqdm
|
12 |
Pillow
|
13 |
diskcache
|
tests/test_model/test_yolo.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
import torch
|
5 |
+
from hydra import compose, initialize
|
6 |
+
from hydra.core.global_hydra import GlobalHydra
|
7 |
+
from omegaconf import DictConfig, OmegaConf
|
8 |
+
|
9 |
+
sys.path.append("./")
|
10 |
+
from model.yolo import YOLO, get_model
|
11 |
+
|
12 |
+
config_path = "../../config/model"
|
13 |
+
config_name = "v7-base"
|
14 |
+
|
15 |
+
|
16 |
+
def test_build_model():
|
17 |
+
|
18 |
+
with initialize(config_path=config_path, version_base=None):
|
19 |
+
model_cfg = compose(config_name=config_name)
|
20 |
+
OmegaConf.set_struct(model_cfg, False)
|
21 |
+
model = YOLO(model_cfg)
|
22 |
+
model.build_model(model_cfg.model)
|
23 |
+
assert len(model.model) == 106
|
24 |
+
|
25 |
+
|
26 |
+
def test_get_model():
|
27 |
+
with initialize(config_path=config_path, version_base=None):
|
28 |
+
model_cfg = compose(config_name=config_name)
|
29 |
+
model = get_model(model_cfg)
|
30 |
+
assert isinstance(model, YOLO)
|
31 |
+
|
32 |
+
|
33 |
+
def test_yolo_forward_output_shape():
|
34 |
+
with initialize(config_path=config_path, version_base=None):
|
35 |
+
model_cfg = compose(config_name=config_name)
|
36 |
+
|
37 |
+
model = get_model(model_cfg)
|
38 |
+
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
39 |
+
dummy_input = torch.rand(2, 3, 640, 640)
|
40 |
+
|
41 |
+
# Forward pass through the model
|
42 |
+
output = model(dummy_input)
|
43 |
+
output_shape = [x.shape for x in output[-1]]
|
44 |
+
assert output_shape == [
|
45 |
+
torch.Size([2, 3, 20, 20, 85]),
|
46 |
+
torch.Size([2, 3, 80, 80, 85]),
|
47 |
+
torch.Size([2, 3, 40, 40, 85]),
|
48 |
+
]
|
tests/test_utils/test_dataaugment.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision.transforms import functional as TF
|
7 |
+
|
8 |
+
sys.path.append("./")
|
9 |
+
from utils.dataargument import Compose, Mosaic, RandomHorizontalFlip
|
10 |
+
|
11 |
+
|
12 |
+
def test_random_horizontal_flip():
|
13 |
+
# Create a mock image and bounding boxes
|
14 |
+
img = Image.new("RGB", (100, 100), color="red")
|
15 |
+
boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]]) # class, xmin, ymin, xmax, ymax
|
16 |
+
|
17 |
+
flip_transform = RandomHorizontalFlip(prob=1) # Set probability to 1 to ensure flip
|
18 |
+
flipped_img, flipped_boxes = flip_transform(img, boxes)
|
19 |
+
|
20 |
+
# Assert image is flipped by comparing it to a manually flipped image
|
21 |
+
assert TF.hflip(img) == flipped_img
|
22 |
+
|
23 |
+
# Assert bounding boxes are flipped correctly
|
24 |
+
expected_boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])
|
25 |
+
expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
|
26 |
+
assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
|
27 |
+
|
28 |
+
|
29 |
+
def test_compose():
|
30 |
+
# Define a mock transform that simply returns its inputs; it is applied twice below
|
31 |
+
def mock_transform(image, boxes):
|
32 |
+
return image, boxes
|
33 |
+
|
34 |
+
compose = Compose([mock_transform, mock_transform])
|
35 |
+
img = Image.new("RGB", (10, 10), color="blue")
|
36 |
+
boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
|
37 |
+
|
38 |
+
transformed_img, transformed_boxes = compose(img, boxes)
|
39 |
+
|
40 |
+
assert transformed_img == img, "Image should not be altered"
|
41 |
+
assert torch.equal(transformed_boxes, boxes), "Boxes should not be altered"
|
42 |
+
|
43 |
+
|
44 |
+
def test_mosaic():
|
45 |
+
img = Image.new("RGB", (100, 100), color="green")
|
46 |
+
boxes = torch.tensor([[0, 0.25, 0.25, 0.75, 0.75]])
|
47 |
+
|
48 |
+
# Mock parent with image_size and get_more_data method
|
49 |
+
class MockParent:
|
50 |
+
image_size = 100
|
51 |
+
|
52 |
+
def get_more_data(self, num_images):
|
53 |
+
return [(img, boxes) for _ in range(num_images)]
|
54 |
+
|
55 |
+
mosaic = Mosaic(prob=1) # Ensure mosaic is applied
|
56 |
+
mosaic.set_parent(MockParent())
|
57 |
+
|
58 |
+
mosaic_img, mosaic_boxes = mosaic(img, boxes)
|
59 |
+
|
60 |
+
# Checks here would depend on the exact expected behavior of the mosaic function,
|
61 |
+
# such as dimensions and content of the output image and boxes.
|
62 |
+
|
63 |
+
assert mosaic_img.size == (200, 200), "Mosaic image size should be doubled"
|
64 |
+
assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
|
tools/layer_helper.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import inspect
|
|
|
2 |
import torch.nn as nn
|
|
|
3 |
from model import module
|
4 |
|
5 |
|
|
|
1 |
import inspect
|
2 |
+
|
3 |
import torch.nn as nn
|
4 |
+
|
5 |
from model import module
|
6 |
|
7 |
|
tools/log_helper.py
CHANGED
@@ -12,6 +12,7 @@ Example:
|
|
12 |
"""
|
13 |
|
14 |
import sys
|
|
|
15 |
from loguru import logger
|
16 |
|
17 |
|
|
|
12 |
"""
|
13 |
|
14 |
import sys
|
15 |
+
|
16 |
from loguru import logger
|
17 |
|
18 |
|
train.py
CHANGED
@@ -1,13 +1,16 @@
|
|
|
|
1 |
from loguru import logger
|
|
|
|
|
2 |
from model.yolo import get_model
|
3 |
from tools.log_helper import custom_logger
|
|
|
4 |
from utils.get_dataset import prepare_dataset
|
5 |
-
import hydra
|
6 |
-
from config.config import Config
|
7 |
|
8 |
|
9 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
10 |
def main(cfg: Config):
|
|
|
11 |
if cfg.download.auto:
|
12 |
prepare_dataset(cfg.download)
|
13 |
|
|
|
1 |
+
import hydra
|
2 |
from loguru import logger
|
3 |
+
|
4 |
+
from config.config import Config
|
5 |
from model.yolo import get_model
|
6 |
from tools.log_helper import custom_logger
|
7 |
+
from utils.dataloader import YoloDataset
|
8 |
from utils.get_dataset import prepare_dataset
|
|
|
|
|
9 |
|
10 |
|
11 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
12 |
def main(cfg: Config):
|
13 |
+
dataset = YoloDataset(cfg)
|
14 |
if cfg.download.auto:
|
15 |
prepare_dataset(cfg.download)
|
16 |
|
utils/data_augment.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
-
from PIL import Image
|
2 |
import numpy as np
|
3 |
import torch
|
|
|
4 |
from torchvision.transforms import functional as TF
|
5 |
-
from torchvision.transforms.functional import
|
6 |
|
7 |
|
8 |
class Compose:
|
9 |
"""Composes several transforms together."""
|
10 |
|
11 |
-
def __init__(self, transforms):
|
12 |
self.transforms = transforms
|
|
|
13 |
|
14 |
for transform in self.transforms:
|
15 |
if hasattr(transform, "set_parent"):
|
@@ -20,9 +21,6 @@ class Compose:
|
|
20 |
image, boxes = transform(image, boxes)
|
21 |
return image, boxes
|
22 |
|
23 |
-
def get_more_data(self):
|
24 |
-
raise NotImplementedError("This method should be overridden by subclass instances!")
|
25 |
-
|
26 |
|
27 |
class RandomHorizontalFlip:
|
28 |
"""Randomly horizontally flips the image along with the bounding boxes."""
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
+
from PIL import Image
|
4 |
from torchvision.transforms import functional as TF
|
5 |
+
from torchvision.transforms.functional import to_pil_image, to_tensor
|
6 |
|
7 |
|
8 |
class Compose:
|
9 |
"""Composes several transforms together."""
|
10 |
|
11 |
+
def __init__(self, transforms, image_size: int = 640):
|
12 |
self.transforms = transforms
|
13 |
+
self.image_size = image_size
|
14 |
|
15 |
for transform in self.transforms:
|
16 |
if hasattr(transform, "set_parent"):
|
|
|
21 |
image, boxes = transform(image, boxes)
|
22 |
return image, boxes
|
23 |
|
|
|
|
|
|
|
24 |
|
25 |
class RandomHorizontalFlip:
|
26 |
"""Randomly horizontally flips the image along with the bounding boxes."""
|
utils/dataloader.py
CHANGED
@@ -1,26 +1,35 @@
|
|
1 |
-
from
|
2 |
-
from
|
3 |
|
|
|
4 |
import hydra
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
-
from torch.utils.data import Dataset
|
8 |
from loguru import logger
|
|
|
|
|
9 |
from tqdm.rich import tqdm
|
10 |
-
|
11 |
-
from
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
class YoloDataset(Dataset):
|
17 |
-
def __init__(self,
|
|
|
|
|
18 |
phase_name = dataset_cfg.get(phase, phase)
|
19 |
self.image_size = image_size
|
20 |
|
21 |
-
|
|
|
22 |
self.transform.get_more_data = self.get_more_data
|
23 |
-
self.transform.image_size = self.image_size
|
24 |
self.data = self.load_data(dataset_cfg.path, phase_name)
|
25 |
|
26 |
def load_data(self, dataset_path, phase_name):
|
@@ -129,8 +138,7 @@ class YoloDataset(Dataset):
|
|
129 |
|
130 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
131 |
def main(cfg):
|
132 |
-
|
133 |
-
dataset = YoloDataset(cfg.data, transform=transform)
|
134 |
draw_bboxes(*dataset[0])
|
135 |
|
136 |
|
|
|
1 |
+
from os import listdir, path
|
2 |
+
from typing import Union
|
3 |
|
4 |
+
import diskcache as dc
|
5 |
import hydra
|
6 |
import numpy as np
|
7 |
import torch
|
|
|
8 |
from loguru import logger
|
9 |
+
from PIL import Image
|
10 |
+
from torch.utils.data import Dataset
|
11 |
from tqdm.rich import tqdm
|
12 |
+
|
13 |
+
from utils.data_augment import (
|
14 |
+
Compose,
|
15 |
+
MixUp,
|
16 |
+
Mosaic,
|
17 |
+
RandomHorizontalFlip,
|
18 |
+
RandomVerticalFlip,
|
19 |
+
)
|
20 |
+
from utils.drawer import draw_bboxes
|
21 |
|
22 |
|
23 |
class YoloDataset(Dataset):
|
24 |
+
def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
|
25 |
+
dataset_cfg = config.data
|
26 |
+
augment_cfg = config.augmentation
|
27 |
phase_name = dataset_cfg.get(phase, phase)
|
28 |
self.image_size = image_size
|
29 |
|
30 |
+
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
31 |
+
self.transform = Compose(transforms, self.image_size)
|
32 |
self.transform.get_more_data = self.get_more_data
|
|
|
33 |
self.data = self.load_data(dataset_cfg.path, phase_name)
|
34 |
|
35 |
def load_data(self, dataset_path, phase_name):
|
|
|
138 |
|
139 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
140 |
def main(cfg):
|
141 |
+
dataset = YoloDataset(cfg)
|
|
|
142 |
draw_bboxes(*dataset[0])
|
143 |
|
144 |
|
utils/get_dataset.py
CHANGED
@@ -2,8 +2,8 @@ import os
|
|
2 |
import zipfile
|
3 |
|
4 |
import hydra
|
5 |
-
from loguru import logger
|
6 |
import requests
|
|
|
7 |
from tqdm.rich import tqdm
|
8 |
|
9 |
|
|
|
2 |
import zipfile
|
3 |
|
4 |
import hydra
|
|
|
5 |
import requests
|
6 |
+
from loguru import logger
|
7 |
from tqdm.rich import tqdm
|
8 |
|
9 |
|