[Merge] branch 'DATASET' into TEST

Changed files:
- .gitignore +2 -2
- README.md +1 -1
- config/config.py +20 -0
- config/config.yaml +10 -0
- config/data/augmentation.yaml +2 -0
- config/data/coco.yaml +1 -0
- config/data/download.yaml +17 -0
- model/yolo.py +7 -22
- requirements.txt +9 -0
- tools/layer_helper.py +20 -0
- tools/log_helper.py +23 -0
- train.py +10 -17
- utils/dataargument.py +79 -0
- utils/dataloader.py +144 -0
- utils/drawer.py +32 -0
- utils/get_dataset.py +84 -0
- utils/tools.py +0 -72
.gitignore
CHANGED

```diff
@@ -111,8 +111,8 @@ dmypy.json
 
 # Machine learning specific folders and symlinks
 runs
-data
-datasets
+/data
+/datasets
 
 # Datasets and model checkpoints
 *.pth
```
README.md
CHANGED

```diff
@@ -3,7 +3,7 @@ An MIT license rewrite of YOLOv9
 
 ## To-Do Lists
 - [ ] Project Setup
-- [ ] requirements
+- [X] requirements
 - [ ] LICENSE
 - [ ] README
 - [ ] pytests
```
config/config.py
ADDED

```python
from dataclasses import dataclass
from typing import List, Dict, Union


@dataclass
class Model:
    anchor: List[List[int]]
    model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]


@dataclass
class Download:
    auto: bool
    path: str


@dataclass
class Config:
    model: Model
    download: Download
```
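These dataclasses act as schema hints for the Hydra-composed config rather than objects the code instantiates directly. A minimal sketch (not part of the commit, assuming omegaconf is installed and the repo root is on the path) of how `Download` lines up with the YAML below:

```python
from omegaconf import OmegaConf

from config.config import Download

# Mirrors the top-level keys of config/data/download.yaml
cfg = OmegaConf.create({"auto": True, "path": "data/coco"})
download: Download = cfg  # duck typing: attribute access matches the dataclass fields
print(download.auto, download.path)  # True data/coco
```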
config/config.yaml
ADDED

```yaml
hydra:
  run:
    dir: ./runs

defaults:
  - data: coco
  - download: ../data/download
  - augmentation: ../data/augmentation
  - model: v7-base
  - _self_
```
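Each defaults entry composes a config group under its own key (data, download, augmentation, model). A sketch (not part of the commit, assuming it runs from a script at the repo root, since `config_path` is resolved relative to the caller) for inspecting the composed tree without running train.py:

```python
import hydra
from omegaconf import OmegaConf

# Compose the same config that train.py receives, outside of @hydra.main.
with hydra.initialize(config_path="config", version_base=None):
    cfg = hydra.compose(config_name="config")
    print(OmegaConf.to_yaml(cfg))  # data, download, augmentation, model sections
```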
config/data/augmentation.yaml
ADDED

```yaml
RandomHorizontalFlip: 0.5
Mosaic: 0.5
```
config/data/coco.yaml
ADDED

```yaml
path: data/coco
```
config/data/download.yaml
ADDED

```yaml
auto: True
path: data/coco
images:
  base_url: http://images.cocodataset.org/zips/
  datasets:
    train:
      file_name: train2017.zip
      num_files: 118287
    val:
      file_name: val2017.zip
      num_files: 5000
    test:
      file_name: test2017.zip
      num_files: 40670
hydra:
  run:
    dir: ./runs
```
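For reference, a small sketch (not part of the commit, run from the repo root) showing the nesting that `prepare_dataset` in utils/get_dataset.py iterates over:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/data/download.yaml")
for name, spec in cfg.images.datasets.items():
    print(name, spec.file_name, spec.num_files)
# train train2017.zip 118287
# val val2017.zip 5000
# test test2017.zip 40670
```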
model/yolo.py
CHANGED

```diff
@@ -1,25 +1,10 @@
-import inspect
 from typing import Any, Dict, List, Union
 
 import torch
 import torch.nn as nn
 from loguru import logger
-
-from model import module
-from utils.tools import load_model_cfg
-
-
-def get_layer_map():
-    """
-    Dynamically generates a dictionary mapping class names to classes,
-    filtering to include only those that are subclasses of nn.Module,
-    ensuring they are relevant neural network layers.
-    """
-    layer_map = {}
-    for name, obj in inspect.getmembers(module, inspect.isclass):
-        if issubclass(obj, nn.Module) and obj is not nn.Module:
-            layer_map[name] = obj
-    return layer_map
+from omegaconf import OmegaConf
+from tools.layer_helper import get_layer_map
 
 
 class YOLO(nn.Module):
@@ -35,16 +20,15 @@ class YOLO(nn.Module):
         super(YOLO, self).__init__()
         self.nc = model_cfg["nc"]
         self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
-        self.build_model(model_cfg["model"])
+        self.build_model(model_cfg.model)
 
     def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
         model_list = nn.ModuleList()
         output_dim = [3]
         layer_indices_by_tag = {}
-
-        for arch_name, arch in model_arch.items():
+        for arch_name in model_arch:
             logger.info(f"🏗️ Building model-{arch_name}")
-            for layer_idx, layer_spec in enumerate(arch, start=1):
+            for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=1):
                 layer_type, layer_info = next(iter(layer_spec.items()))
                 layer_args = layer_info.get("args", {})
                 source = layer_info.get("source", -1)
@@ -74,7 +58,7 @@ class YOLO(nn.Module):
         y = [x]
         output = []
         for layer in self.model:
-            if isinstance(layer.source, list):
+            if OmegaConf.is_list(layer.source):
                 model_input = [y[idx] for idx in layer.source]
             else:
                 model_input = y[layer.source]
@@ -113,6 +97,7 @@ def get_model(model_cfg: dict) -> YOLO:
     Returns:
         YOLO: An instance of the model defined by the given configuration.
     """
+    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg)
     logger.info("✅ Success load model")
     return model
```
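The switch from an isinstance check to `OmegaConf.is_list` matters because Hydra hands the model DictConfig/ListConfig nodes, not builtin containers. A quick illustration (not part of the commit):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"concat": {"source": [-1, -3]}, "conv": {"source": -1}})
print(isinstance(cfg.concat.source, list))   # False: it is a ListConfig, not a list
print(OmegaConf.is_list(cfg.concat.source))  # True
print(OmegaConf.is_list(cfg.conv.source))    # False: a plain int
```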
requirements.txt
ADDED

```
hydra-core
loguru
numpy
pytest
pyyaml
requests
rich
torch
tqdm
```
tools/layer_helper.py
ADDED

```python
import inspect
import torch.nn as nn
from model import module


def auto_pad():
    raise NotImplementedError


def get_layer_map():
    """
    Dynamically generates a dictionary mapping class names to classes,
    filtering to include only those that are subclasses of nn.Module,
    ensuring they are relevant neural network layers.
    """
    layer_map = {}
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, nn.Module) and obj is not nn.Module:
            layer_map[name] = obj
    return layer_map
```
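`get_layer_map` resolves the layer-type names used in the model YAML to classes defined in model/module.py (a file not shown in this diff). The same introspection pattern, demonstrated against torch.nn so the snippet is self-contained (not part of the commit):

```python
import inspect

import torch.nn as nn

# Identical filtering to get_layer_map, applied to torch.nn instead of model.module.
layer_map = {
    name: obj
    for name, obj in inspect.getmembers(nn, inspect.isclass)
    if issubclass(obj, nn.Module) and obj is not nn.Module
}
print(layer_map["Conv2d"] is nn.Conv2d)  # True
print("Module" in layer_map)             # False: the bare base class is excluded
```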
tools/log_helper.py
ADDED

```python
"""
Module for initializing logging tools used in machine learning and data processing.
Supports integration with Weights & Biases (wandb), Loguru, TensorBoard, and other
logging frameworks as needed.

This setup ensures consistent logging across various platforms, facilitating
effective monitoring and debugging.

Example:
    from tools.log_helper import custom_logger
    custom_logger()
"""

import sys

from loguru import logger


def custom_logger():
    logger.remove()
    logger.add(
        sys.stderr,
        format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
    )
```
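A usage sketch (not part of the commit) showing the resulting log line shape:

```python
from loguru import logger

from tools.log_helper import custom_logger

custom_logger()
logger.info("Dataset ready")
# stderr (colors omitted): 04-30 18:21:07 | INFO     | Dataset ready
```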
train.py
CHANGED

```diff
@@ -1,26 +1,19 @@
-import argparse
 from loguru import logger
 from model.yolo import get_model
-from utils.tools import load_model_cfg, custom_logger
+from tools.log_helper import custom_logger
+from utils.get_dataset import prepare_dataset
+import hydra
+from config.config import Config
 
 
-def parse_arguments():
-    """
-    Parse the command-line arguments.
-
-    Returns:
-        argparse.Namespace: The command-line arguments object with 'config' attribute.
-    """
-    parser = argparse.ArgumentParser(description="Load a YOLO model configuration and display the model.")
-    parser.add_argument(
-        "--model-config", type=str, default="v7-base", help="Name or path to the model configuration file."
-    )
-    return parser.parse_args()
+@hydra.main(config_path="config", config_name="config", version_base=None)
+def main(cfg: Config):
+    if cfg.download.auto:
+        prepare_dataset(cfg.download)
+
+    model = get_model(cfg.model)
 
 
 if __name__ == "__main__":
     custom_logger()
-    args = parse_arguments()
-    model_cfg = load_model_cfg(args.model_config)
-    model = get_model(model_cfg)
-    logger.info("Success load model")
+    main()
```
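With argparse gone, runtime options now use Hydra's override grammar on the command line, for example (hypothetical invocations, not from the commit):

```
python train.py download.auto=False
python train.py model=v7-base data.path=data/coco
```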
utils/dataargument.py
ADDED

```python
from PIL import Image
import numpy as np
import torch
from torchvision.transforms import functional as TF


class Compose:
    """Composes several transforms together."""

    def __init__(self, transforms):
        self.transforms = transforms

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes):
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        return image, boxes

    def get_more_data(self):
        raise NotImplementedError("This method should be overridden by subclass instances!")


class RandomHorizontalFlip:
    """Randomly horizontally flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) < self.prob:
            image = TF.hflip(image)
            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return image, boxes


class Mosaic:
    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        self.parent = parent

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."

        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
        more_data = self.parent.get_more_data(3)  # get 3 more images randomly

        data = [(image, boxes)] + more_data
        mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        center = np.array([img_sz, img_sz])
        all_labels = []

        for (image, boxes), vector in zip(data, vectors):
            this_w, this_h = image.size
            coord = tuple(center + vector * np.array([this_w, this_h]))

            mosaic_image.paste(image, coord)
            xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
            xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
            xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
            ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
            ymax = (ymax * this_h + coord[1]) / (2 * img_sz)

            adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
            all_labels.append(adjusted_boxes)

        all_labels = torch.cat(all_labels, dim=0)
        return mosaic_image, all_labels
```
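A minimal sketch (not part of the commit, run from the repo root) exercising the flip on dummy data; boxes are [class_id, xmin, ymin, xmax, ymax] normalized to [0, 1]:

```python
import torch
from PIL import Image

from utils.dataargument import Compose, RandomHorizontalFlip

img = Image.new("RGB", (640, 640))
boxes = torch.tensor([[0.0, 0.1, 0.2, 0.4, 0.6]])

transform = Compose([RandomHorizontalFlip(prob=1.0)])  # prob=1.0: always flip
img, boxes = transform(img, boxes)
print(boxes)  # x-range mirrors to [1-0.4, 1-0.1] = [0.6, 0.9]; y is unchanged
```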
utils/dataloader.py
ADDED

```python
from PIL import Image
from os import path, listdir

import hydra
import numpy as np
import torch
from torch.utils.data import Dataset
from loguru import logger
from tqdm.rich import tqdm
import diskcache as dc
from typing import Union
from drawer import draw_bboxes
from dataargument import Compose, RandomHorizontalFlip, Mosaic


class YoloDataset(Dataset):
    def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
        phase_name = dataset_cfg.get(phase, phase)
        self.image_size = image_size

        self.transform = transform
        self.transform.get_more_data = self.get_more_data
        self.transform.image_size = self.image_size
        self.data = self.load_data(dataset_cfg.path, phase_name)

    def load_data(self, dataset_path, phase_name):
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (str): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            dict: The loaded data from the cache for the specified phase.
        """
        cache_path = path.join(dataset_path, ".cache")
        cache = dc.Cache(cache_path)
        data = cache.get(phase_name)

        if data is None:
            logger.info("Generating {} cache", phase_name)
            images_path = path.join(dataset_path, phase_name, "images")
            labels_path = path.join(dataset_path, phase_name, "labels")
            data = self.filter_data(images_path, labels_path)
            cache[phase_name] = data

        cache.close()
        logger.info("Loaded {} cache", phase_name)
        data = cache[phase_name]
        return data

    def filter_data(self, images_path: str, labels_path: str) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        Parameters:
            images_path (str): Path to the directory containing image files.
            labels_path (str): Path to the directory containing label files.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
        """
        data = []
        valid_inputs = 0
        images_list = sorted(listdir(images_path))
        for image_name in tqdm(images_list, desc="Filtering data"):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue

            img_path = path.join(images_path, image_name)
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, f"{base_name}.txt")

            if path.isfile(label_path):
                labels = self.load_valid_labels(label_path)
                if labels is not None:
                    data.append((img_path, labels))
                    valid_inputs += 1

        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
        return data

    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
        """
        Loads bounding box data from a label file and validates that it lies in [0, 1].

        Parameters:
            label_path (str): The filepath to the label file containing bounding box data.

        Returns:
            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
        """
        bboxes = []
        with open(label_path, "r") as file:
            for line in file:
                parts = list(map(float, line.strip().split()))
                cls = parts[0]
                points = np.array(parts[1:]).reshape(-1, 2)
                valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
                if valid_points.size > 1:
                    bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
                    bboxes.append(bbox)

        if bboxes:
            return torch.stack(bboxes)
        else:
            logger.warning("No valid BBox in {}", label_path)
            return None

    def get_data(self, idx):
        img_path, bboxes = self.data[idx]
        img = Image.open(img_path).convert("RGB")
        return img, bboxes

    def get_more_data(self, num: int = 1):
        indices = torch.randint(0, len(self), (num,))
        return [self.get_data(idx) for idx in indices]

    def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
        img, bboxes = self.get_data(idx)
        if self.transform:
            img, bboxes = self.transform(img, bboxes)
        return img, bboxes

    def __len__(self) -> int:
        return len(self.data)


@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg):
    transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()])
    dataset = YoloDataset(cfg.data, transform=transform)
    draw_bboxes(*dataset[0])


if __name__ == "__main__":
    import sys

    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
    main()
```
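`main()` above builds the transform list with eval() on the YAML keys from augmentation.yaml. An equivalent eval-free sketch (not part of the commit) using an explicit registry:

```python
from utils.dataargument import Compose, Mosaic, RandomHorizontalFlip

AUGMENTATIONS = {"RandomHorizontalFlip": RandomHorizontalFlip, "Mosaic": Mosaic}


def build_transform(augmentation_cfg) -> Compose:
    # Same result as eval(aug)(prob), without evaluating config strings as code.
    return Compose([AUGMENTATIONS[name](prob) for name, prob in augmentation_cfg.items()])


transform = build_transform({"RandomHorizontalFlip": 0.5, "Mosaic": 0.5})
```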
utils/drawer.py
ADDED

```python
from PIL import Image, ImageDraw, ImageFont


def draw_bboxes(img, bboxes):
    """
    Draw bounding boxes on an image.

    Args:
    - img (PIL.Image.Image): The image to draw on.
    - bboxes (list of lists/tuples): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
      coordinates normalized to [0, 1].
    """
    draw = ImageDraw.Draw(img)

    # Font for class_id (optional)
    try:
        font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        font = ImageFont.load_default(30)
    width, height = img.size

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max = bbox
        x_min = x_min * width
        x_max = x_max * width
        y_min = y_min * height
        y_max = y_max * height
        shape = [(x_min, y_min), (x_max, y_max)]
        draw.rectangle(shape, outline="red", width=2)
        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")

    img.save("output.jpg")
```
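Usage sketch (not part of the commit); boxes follow the same normalized [class_id, xmin, ymin, xmax, ymax] layout as the dataset:

```python
import torch
from PIL import Image

from utils.drawer import draw_bboxes

img = Image.new("RGB", (640, 640), "white")
boxes = torch.tensor([[0.0, 0.25, 0.25, 0.75, 0.75]])
draw_bboxes(img, boxes)  # writes output.jpg with one red box labelled "0"
```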
utils/get_dataset.py
ADDED

```python
import os
import zipfile

import hydra
from loguru import logger
import requests
from tqdm.rich import tqdm


def download_file(url, dest_path):
    """
    Downloads a file from a specified URL to a destination path with progress logging.
    """
    logger.info(f"Downloading {os.path.basename(dest_path)}...")
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_length = int(r.headers.get("content-length", 0))
        with open(dest_path, "wb") as f, tqdm(
            total=total_length, unit="iB", unit_scale=True, desc=os.path.basename(dest_path), leave=True
        ) as bar:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                bar.update(len(chunk))
    logger.info("Download complete!")


def unzip_file(zip_path, extract_to):
    """
    Unzips a ZIP file to a specified directory.
    """
    logger.info(f"Unzipping {os.path.basename(zip_path)}...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_to)
    os.remove(zip_path)
    logger.info(f"Removed {zip_path}")


def check_files(directory, expected_count):
    """
    Checks if the specified directory has the expected number of files.
    """
    num_files = len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])
    return num_files == expected_count


@hydra.main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(download_cfg):
    data_dir = download_cfg.path
    base_url = download_cfg.images.base_url
    datasets = download_cfg.images.datasets

    for dataset_type in datasets:
        file_name, expected_files = datasets[dataset_type].values()
        url = f"{base_url}{file_name}"
        local_zip_path = os.path.join(data_dir, file_name)
        extract_to = os.path.join(data_dir, dataset_type, "images")

        # Ensure the extraction directory exists
        os.makedirs(extract_to, exist_ok=True)

        # Check if the correct number of files exists
        if check_files(extract_to, expected_files):
            logger.info(f"Dataset {dataset_type} already verified.")
            continue

        if os.path.exists(local_zip_path):
            logger.info(f"Dataset {dataset_type} already downloaded.")
        else:
            download_file(url, local_zip_path)

        unzip_file(local_zip_path, extract_to)

        print(os.path.exists(local_zip_path), check_files(extract_to, expected_files))

        # Additional verification post extraction
        if not check_files(extract_to, expected_files):
            logger.error(f"Error in verifying the {dataset_type} dataset after extraction.")


if __name__ == "__main__":
    from tools.log_helper import custom_logger

    custom_logger()
    prepare_dataset()
```
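The intended layout after a successful run, per the num_files counts in download.yaml (the COCO image zips contain images only; label files are not part of these archives):

```
data/coco/
├── train/images/   # 118287 files from train2017.zip
├── val/images/     # 5000 files from val2017.zip
└── test/images/    # 40670 files from test2017.zip
```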
utils/tools.py
DELETED

```python
import os
import sys
import yaml
from loguru import logger
from typing import Dict, Any


def complete_path(file_name: str = "v7-base.yaml") -> str:
    """
    Ensures the path to a model configuration is an existing file.

    Parameters:
        file_name (str): The filename or path, with default 'v7-base.yaml'.

    Returns:
        str: A complete path with necessary prefix and extension.
    """
    # Ensure the file has the '.yaml' extension if missing
    if not file_name.endswith(".yaml"):
        file_name += ".yaml"

    # Add folder prefix if only the filename is provided
    if os.path.dirname(file_name) == "":
        file_name = os.path.join("./config/model", file_name)

    return file_name


def load_model_cfg(file_path: str) -> Dict[str, Any]:
    """
    Read a YAML configuration file, ensure necessary keys are present, and return its content as a dictionary.

    Args:
        file_path (str): The path to the YAML configuration file.

    Returns:
        Dict[str, Any]: The contents of the YAML file as a dictionary.

    Raises:
        FileNotFoundError: If the YAML file cannot be found.
        yaml.YAMLError: If there is an error parsing the YAML file.
    """
    file_path = complete_path(file_path)
    try:
        with open(file_path, "r") as file:
            model_cfg = yaml.safe_load(file) or {}

        # Check for required keys and set defaults if not present
        if "nc" not in model_cfg:
            model_cfg["nc"] = 80
            logger.warning("'nc' not found in the YAML file. Setting default 'nc' to 80.")

        if "model" not in model_cfg:
            logger.error("'model' is missing in the configuration file.")
            raise ValueError("Missing required key: 'model'")

        return model_cfg

    except FileNotFoundError:
        logger.error(f"YAML file not found: {file_path}")
        raise
    except yaml.YAMLError as e:
        logger.error(f"Error parsing YAML file: {e}")
        raise


def custom_logger():
    logger.remove()
    logger.add(
        sys.stderr,
        format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
    )
```