🔀 [Merge] branch 'main' into DATASET

Files changed:
- config/config.py +0 -91
- config/config.yaml +0 -11
- config/data/augmentation.yaml +0 -3
- config/data/download.yaml +0 -21
- config/hyper/default.yaml +0 -19
- requirements.txt +2 -0
- tests/test_model/test_yolo.py +7 -7
- tests/test_utils/test_dataaugment.py +5 -2
- tests/test_utils/test_loss.py +39 -0
- train.py +0 -29
- utils/converter_json2txt.py +0 -86
- utils/data_augment.py +0 -125
- utils/dataloader.py +0 -206
- utils/drawer.py +0 -41
- utils/get_dataset.py +0 -84
- utils/loss.py +0 -2
- yolo/config/config.py +17 -1
- yolo/config/hyper/default.yaml +16 -0
- yolo/config/model/v7-base.yaml +4 -0
- yolo/tools/bbox_helper.py +251 -0
- yolo/utils/loss.py +164 -0
config/config.py
DELETED
@@ -1,91 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Union
-
-
-@dataclass
-class Model:
-    anchor: List[List[int]]
-    model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
-
-
-@dataclass
-class Download:
-    auto: bool
-    path: str
-
-
-@dataclass
-class DataLoaderConfig:
-    batch_size: int
-    shuffle: bool
-    num_workers: int
-    pin_memory: bool
-
-
-@dataclass
-class OptimizerArgs:
-    lr: float
-    weight_decay: float
-
-
-@dataclass
-class OptimizerConfig:
-    type: str
-    args: OptimizerArgs
-
-
-@dataclass
-class SchedulerArgs:
-    step_size: int
-    gamma: float
-
-
-@dataclass
-class SchedulerConfig:
-    type: str
-    args: SchedulerArgs
-
-
-@dataclass
-class EMAConfig:
-    enabled: bool
-    decay: float
-
-
-@dataclass
-class TrainConfig:
-    optimizer: OptimizerConfig
-    scheduler: SchedulerConfig
-    ema: EMAConfig
-
-
-@dataclass
-class HyperConfig:
-    data: DataLoaderConfig
-    train: TrainConfig
-
-
-@dataclass
-class Dataset:
-    file_name: str
-    num_files: int
-
-
-@dataclass
-class Datasets:
-    base_url: str
-    images: Dict[str, Dataset]
-
-
-@dataclass
-class Download:
-    auto: bool
-    save_path: str
-    datasets: Datasets
-
-
-@dataclass
-class Config:
-    model: Model
-    download: Download
-    hyper: HyperConfig
config/config.yaml
DELETED
@@ -1,11 +0,0 @@
-hydra:
-  run:
-    dir: ./runs
-
-defaults:
-  - data: coco
-  - download: ../data/download
-  - augmentation: ../data/augmentation
-  - model: v7-base
-  - hyper: default
-  - _self_
config/data/augmentation.yaml
DELETED
@@ -1,3 +0,0 @@
-Mosaic: 1
-# MixUp: 1
-HorizontalFlip: 0.5
config/data/download.yaml
DELETED
@@ -1,21 +0,0 @@
-auto: True
-save_path: data/coco
-datasets:
-  images:
-    base_url: http://images.cocodataset.org/zips/
-    train2017:
-      file_name: train2017
-      file_num: 118287
-    val2017:
-      file_name: val2017
-      file_num: 5000
-    test2017:
-      file_name: test2017
-      file_num: 40670
-  annotations:
-    base_url: http://images.cocodataset.org/annotations/
-    annotations:
-      file_name: annotations_trainval2017
-hydra:
-  run:
-    dir: ./runs
config/hyper/default.yaml
DELETED
@@ -1,19 +0,0 @@
-data:
-  batch_size: 4
-  shuffle: True
-  num_workers: 4
-  pin_memory: True
-train:
-  optimizer:
-    type: Adam
-    args:
-      lr: 0.001
-      weight_decay: 0.0001
-  scheduler:
-    type: StepLR
-    args:
-      step_size: 10
-      gamma: 0.1
-  ema:
-    enabled: true
-    decay: 0.995
requirements.txt
CHANGED
@@ -1,3 +1,5 @@
+diskcache
+einops
 hydra-core
 loguru
 numpy
tests/test_model/test_yolo.py
CHANGED
@@ -1,20 +1,20 @@
 import sys
+from pathlib import Path
 
-import pytest
 import torch
 from hydra import compose, initialize
-from
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import OmegaConf
 
-
-
+project_root = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(project_root))
 
-
+from yolo.model.yolo import YOLO, get_model
+
+config_path = "../../yolo/config/model"
 config_name = "v7-base"
 
 
 def test_build_model():
-
     with initialize(config_path=config_path, version_base=None):
         model_cfg = compose(config_name=config_name)
         OmegaConf.set_struct(model_cfg, False)
tests/test_utils/test_dataaugment.py
CHANGED
@@ -1,12 +1,15 @@
 import sys
+from pathlib import Path
 
 import pytest
 import torch
 from PIL import Image
 from torchvision.transforms import functional as TF
 
-
-
+project_root = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(project_root))
+
+from yolo.utils.data_augment import Compose, HorizontalFlip, Mosaic, VerticalFlip
 
 
 def test_horizontal_flip():
tests/test_utils/test_loss.py
ADDED
@@ -0,0 +1,39 @@
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from hydra import compose, initialize
+
+project_root = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(project_root))
+
+from yolo.utils.loss import YOLOLoss
+
+
+@pytest.fixture
+def cfg():
+    with initialize(config_path="../../yolo/config", version_base=None):
+        cfg = compose(config_name="config")
+    return cfg
+
+
+@pytest.fixture
+def loss_function(cfg) -> YOLOLoss:
+    return YOLOLoss(cfg)
+
+
+@pytest.fixture
+def data():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    targets = torch.zeros(20, 6, device=device)
+    predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
+    return predicts, targets
+
+
+def test_yolo_loss(loss_function, data):
+    predicts, targets = data
+    loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
+    assert torch.isnan(loss_iou)
+    assert torch.isnan(loss_dfl)
+    assert torch.isinf(loss_cls)
train.py
DELETED
@@ -1,29 +0,0 @@
-import hydra
-import torch
-from loguru import logger
-
-from config.config import Config
-from model.yolo import get_model
-from tools.log_helper import custom_logger
-from tools.trainer import Trainer
-from utils.dataloader import get_dataloader
-from utils.get_dataset import prepare_dataset
-
-
-@hydra.main(config_path="config", config_name="config", version_base=None)
-def main(cfg: Config):
-    if cfg.download.auto:
-        prepare_dataset(cfg.download)
-
-    dataloader = get_dataloader(cfg)
-    model = get_model(cfg.model)
-    # TODO: get_device or rank, for DDP mode
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    trainer = Trainer(model, cfg.hyper.train, device)
-    trainer.train(dataloader, 10)
-
-
-if __name__ == "__main__":
-    custom_logger()
-    main()
utils/converter_json2txt.py
DELETED
@@ -1,86 +0,0 @@
-import json
-import os
-from typing import Dict, List, Optional
-
-from tqdm import tqdm
-
-
-def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
-    """
-    Maps each unique 'id' in the list of category dictionaries to a sequential integer index.
-    Indices are assigned based on the sorted 'id' values.
-    """
-    sorted_categories = sorted(categories, key=lambda category: category["id"])
-    return {category["id"]: index for index, category in enumerate(sorted_categories)}
-
-
-def process_annotations(
-    image_annotations: Dict[int, List[Dict]],
-    image_info_dict: Dict[int, tuple],
-    output_dir: str,
-    id_to_idx: Optional[Dict[int, int]] = None,
-) -> None:
-    """
-    Process and save annotations to files, with option to remap category IDs.
-    """
-    for image_id, annotations in tqdm(image_annotations.items(), desc="Processing annotations"):
-        file_path = os.path.join(output_dir, f"{image_id:0>12}.txt")
-        if not annotations:
-            continue
-        with open(file_path, "w") as file:
-            for annotation in annotations:
-                process_annotation(annotation, image_info_dict[image_id], id_to_idx, file)
-
-
-def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
-    """
-    Convert a single annotation's segmentation and write it to the open file handle.
-    """
-    category_id = annotation["category_id"]
-    segmentation = (
-        annotation["segmentation"][0]
-        if annotation["segmentation"] and isinstance(annotation["segmentation"][0], list)
-        else None
-    )
-
-    if segmentation is None:
-        return
-
-    img_width, img_height = image_dims
-    normalized_segmentation = normalize_segmentation(segmentation, img_width, img_height)
-
-    if id_to_idx:
-        category_id = id_to_idx.get(category_id, category_id)
-
-    file.write(f"{category_id} {' '.join(normalized_segmentation)}\n")
-
-
-def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
-    """
-    Normalize and format segmentation coordinates.
-    """
-    return [f"{x/img_width:.6f}" if i % 2 == 0 else f"{x/img_height:.6f}" for i, x in enumerate(segmentation)]
-
-
-def convert_annotations(json_file: str, output_dir: str) -> None:
-    """
-    Load annotation data from a JSON file and process all annotations.
-    """
-    with open(json_file) as file:
-        data = json.load(file)
-
-    os.makedirs(output_dir, exist_ok=True)
-
-    image_info_dict = {img["id"]: (img["width"], img["height"]) for img in data.get("images", [])}
-    id_to_idx = discretize_categories(data.get("categories", [])) if "categories" in data else None
-    image_annotations = {img_id: [] for img_id in image_info_dict}
-
-    for annotation in data.get("annotations", []):
-        if not annotation.get("iscrowd", False):
-            image_annotations[annotation["image_id"]].append(annotation)
-
-    process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx)
-
-
-convert_annotations("./data/coco/annotations/instances_train2017.json", "./data/coco/labels/train2017/")
-convert_annotations("./data/coco/annotations/instances_val2017.json", "./data/coco/labels/val2017/")
utils/data_augment.py
DELETED
@@ -1,125 +0,0 @@
-import numpy as np
-import torch
-from PIL import Image
-from torchvision.transforms import functional as TF
-
-
-class Compose:
-    """Composes several transforms together."""
-
-    def __init__(self, transforms, image_size: int = 640):
-        self.transforms = transforms
-        self.image_size = image_size
-
-        for transform in self.transforms:
-            if hasattr(transform, "set_parent"):
-                transform.set_parent(self)
-
-    def __call__(self, image, boxes):
-        for transform in self.transforms:
-            image, boxes = transform(image, boxes)
-        return image, boxes
-
-
-class HorizontalFlip:
-    """Randomly horizontally flips the image along with the bounding boxes."""
-
-    def __init__(self, prob=0.5):
-        self.prob = prob
-
-    def __call__(self, image, boxes):
-        if torch.rand(1) < self.prob:
-            image = TF.hflip(image)
-            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
-        return image, boxes
-
-
-class VerticalFlip:
-    """Randomly vertically flips the image along with the bounding boxes."""
-
-    def __init__(self, prob=0.5):
-        self.prob = prob
-
-    def __call__(self, image, boxes):
-        if torch.rand(1) < self.prob:
-            image = TF.vflip(image)
-            boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
-        return image, boxes
-
-
-class Mosaic:
-    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""
-
-    def __init__(self, prob=0.5):
-        self.prob = prob
-        self.parent = None
-
-    def set_parent(self, parent):
-        self.parent = parent
-
-    def __call__(self, image, boxes):
-        if torch.rand(1) >= self.prob:
-            return image, boxes
-
-        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
-
-        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
-        more_data = self.parent.get_more_data(3)  # get 3 more images randomly
-
-        data = [(image, boxes)] + more_data
-        mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
-        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
-        center = np.array([img_sz, img_sz])
-        all_labels = []
-
-        for (image, boxes), vector in zip(data, vectors):
-            this_w, this_h = image.size
-            coord = tuple(center + vector * np.array([this_w, this_h]))
-
-            mosaic_image.paste(image, coord)
-            xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
-            xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
-            xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
-            ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
-            ymax = (ymax * this_h + coord[1]) / (2 * img_sz)
-
-            adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
-            all_labels.append(adjusted_boxes)
-
-        all_labels = torch.cat(all_labels, dim=0)
-        mosaic_image = mosaic_image.resize((img_sz, img_sz))
-        return mosaic_image, all_labels
-
-
-class MixUp:
-    """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""
-
-    def __init__(self, prob=0.5, alpha=1.0):
-        self.alpha = alpha
-        self.prob = prob
-        self.parent = None
-
-    def set_parent(self, parent):
-        """Set the parent dataset object for accessing dataset methods."""
-        self.parent = parent
-
-    def __call__(self, image, boxes):
-        if torch.rand(1) >= self.prob:
-            return image, boxes
-
-        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."
-
-        # Retrieve another image and its boxes randomly from the dataset
-        image2, boxes2 = self.parent.get_more_data()[0]
-
-        # Calculate the mixup lambda parameter
-        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
-
-        # Mix images
-        image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
-        mixed_image = lam * image1 + (1 - lam) * image2
-
-        # Mix bounding boxes
-        mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
-
-        return TF.to_pil_image(mixed_image), mixed_boxes
utils/dataloader.py
DELETED
@@ -1,206 +0,0 @@
-import os
-from os import path
-from typing import List, Tuple, Union
-
-import diskcache as dc
-import hydra
-import numpy as np
-import torch
-from loguru import logger
-from PIL import Image
-from torch.utils.data import DataLoader, Dataset
-from torchvision.transforms import functional as TF
-from tqdm.rich import tqdm
-
-from tools.dataset_helper import (
-    create_image_info_dict,
-    find_labels_path,
-    get_scaled_segmentation,
-)
-from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
-from utils.drawer import draw_bboxes
-
-
-class YoloDataset(Dataset):
-    def __init__(self, config: dict, phase: str = "train2017", image_size: int = 640):
-        dataset_cfg = config.data
-        augment_cfg = config.augmentation
-        phase_name = dataset_cfg.get(phase, phase)
-        self.image_size = image_size
-
-        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
-        self.transform = Compose(transforms, self.image_size)
-        self.transform.get_more_data = self.get_more_data
-        self.data = self.load_data(dataset_cfg.path, phase_name)
-
-    def load_data(self, dataset_path, phase_name):
-        """
-        Loads data from a cache or generates a new cache for a specific dataset phase.
-
-        Parameters:
-            dataset_path (str): The root path to the dataset directory.
-            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
-
-        Returns:
-            dict: The loaded data from the cache for the specified phase.
-        """
-        cache_path = path.join(dataset_path, ".cache")
-        cache = dc.Cache(cache_path)
-        data = cache.get(phase_name)
-
-        if data is None:
-            logger.info("Generating {} cache", phase_name)
-            data = self.filter_data(dataset_path, phase_name)
-            cache[phase_name] = data
-
-        cache.close()
-        logger.info("📦 Loaded {} cache", phase_name)
-        data = cache[phase_name]
-        return data
-
-    def filter_data(self, dataset_path: str, phase_name: str) -> list:
-        """
-        Filters and collects dataset information by pairing images with their corresponding labels.
-
-        Parameters:
-            images_path (str): Path to the directory containing image files.
-            labels_path (str): Path to the directory containing label files.
-
-        Returns:
-            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
-        """
-        images_path = path.join(dataset_path, "images", phase_name)
-        labels_path, data_type = find_labels_path(dataset_path, phase_name)
-        images_list = sorted(os.listdir(images_path))
-        if data_type == "json":
-            annotations_index, image_info_dict = create_image_info_dict(labels_path)
-
-        data = []
-        valid_inputs = 0
-        for image_name in tqdm(images_list, desc="Filtering data"):
-            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
-                continue
-            image_id, _ = path.splitext(image_name)
-
-            if data_type == "json":
-                image_info = image_info_dict.get(image_id, None)
-                if image_info is None:
-                    continue
-                annotations = annotations_index.get(image_info["id"], [])
-                image_seg_annotations = get_scaled_segmentation(annotations, image_info)
-                if not image_seg_annotations:
-                    continue
-
-            elif data_type == "txt":
-                label_path = path.join(labels_path, f"{image_id}.txt")
-                if not path.isfile(label_path):
-                    continue
-                with open(label_path, "r") as file:
-                    image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
-
-            labels = self.load_valid_labels(image_id, image_seg_annotations)
-            if labels is not None:
-                img_path = path.join(images_path, image_name)
-                data.append((img_path, labels))
-                valid_inputs += 1
-
-        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
-        return data
-
-    def load_valid_labels(self, label_path, seg_data_one_img) -> Union[torch.Tensor, None]:
-        """
-        Loads and validates bounding box data is [0, 1] from a label file.
-
-        Parameters:
-            label_path (str): The filepath to the label file containing bounding box data.
-
-        Returns:
-            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
-        """
-        bboxes = []
-        for seg_data in seg_data_one_img:
-            cls = seg_data[0]
-            points = np.array(seg_data[1:]).reshape(-1, 2)
-            valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
-            if valid_points.size > 1:
-                bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
-                bboxes.append(bbox)
-
-        if bboxes:
-            return torch.stack(bboxes)
-        else:
-            logger.warning("No valid BBox in {}", label_path)
-            return None
-
-    def get_data(self, idx):
-        img_path, bboxes = self.data[idx]
-        img = Image.open(img_path).convert("RGB")
-        return img, bboxes
-
-    def get_more_data(self, num: int = 1):
-        indices = torch.randint(0, len(self), (num,))
-        return [self.get_data(idx) for idx in indices]
-
-    def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
-        img, bboxes = self.get_data(idx)
-        if self.transform:
-            img, bboxes = self.transform(img, bboxes)
-        img = TF.to_tensor(img)
-        return img, bboxes
-
-    def __len__(self) -> int:
-        return len(self.data)
-
-
-class YoloDataLoader(DataLoader):
-    def __init__(self, config: dict):
-        """Initializes the YoloDataLoader with hydra-config files."""
-        hyper = config.hyper.data
-        dataset = YoloDataset(config)
-
-        super().__init__(
-            dataset,
-            batch_size=hyper.batch_size,
-            shuffle=hyper.shuffle,
-            num_workers=hyper.num_workers,
-            pin_memory=hyper.pin_memory,
-            collate_fn=self.collate_fn,
-        )
-
-    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        """
-        A collate function to handle batching of images and their corresponding targets.
-
-        Args:
-            batch (list of tuples): Each tuple contains:
-                - image (torch.Tensor): The image tensor.
-                - labels (torch.Tensor): The tensor of labels for the image.
-
-        Returns:
-            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
-                - A tensor of batched images.
-                - A list of tensors, each corresponding to bboxes for each image in the batch.
-        """
-        images = torch.stack([item[0] for item in batch])
-        targets = [item[1] for item in batch]
-        return images, targets
-
-
-def get_dataloader(config):
-    return YoloDataLoader(config)
-
-
-@hydra.main(config_path="../config", config_name="config", version_base=None)
-def main(cfg):
-    dataloader = get_dataloader(cfg)
-    draw_bboxes(*next(iter(dataloader)))
-
-
-if __name__ == "__main__":
-    import sys
-
-    sys.path.append("./")
-    from tools.log_helper import custom_logger
-
-    custom_logger()
-    main()
utils/drawer.py
DELETED
@@ -1,41 +0,0 @@
-from typing import List, Union
-
-import torch
-from loguru import logger
-from PIL import Image, ImageDraw, ImageFont
-from torchvision.transforms.functional import to_pil_image
-
-
-def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
-    """
-    Draw bounding boxes on an image.
-
-    Args:
-    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
-    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
-      where coordinates are normalized [0, 1].
-    """
-    # Convert tensor image to PIL Image if necessary
-    if isinstance(img, torch.Tensor):
-        if img.dim() > 3:
-            logger.info("Multi-frame tensor detected, using the first image.")
-            img = img[0]
-            bboxes = bboxes[0]
-        img = to_pil_image(img)
-
-    draw = ImageDraw.Draw(img)
-    width, height = img.size
-    font = ImageFont.load_default(30)
-
-    for bbox in bboxes:
-        class_id, x_min, y_min, x_max, y_max = bbox
-        x_min = x_min * width
-        x_max = x_max * width
-        y_min = y_min * height
-        y_max = y_max * height
-        shape = [(x_min, y_min), (x_max, y_max)]
-        draw.rectangle(shape, outline="red", width=3)
-        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
-
-    img.save("visualize.jpg")  # Save the image with annotations
-    logger.info("Saved visualize image at visualize.png")
utils/get_dataset.py
DELETED
@@ -1,84 +0,0 @@
-import os
-import zipfile
-
-import requests
-from hydra import main
-from loguru import logger
-from tqdm import tqdm
-
-
-def download_file(url, destination):
-    """
-    Downloads a file from the specified URL to the destination path with progress logging.
-    """
-    logger.info(f"Downloading {os.path.basename(destination)}...")
-    with requests.get(url, stream=True) as response:
-        response.raise_for_status()
-        total_size = int(response.headers.get("content-length", 0))
-        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
-
-        with open(destination, "wb") as file:
-            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
-                file.write(data)
-                progress.update(len(data))
-        progress.close()
-    logger.info("Download completed.")
-
-
-def unzip_file(source, destination):
-    """
-    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
-    """
-    logger.info(f"Unzipping {os.path.basename(source)}...")
-    with zipfile.ZipFile(source, "r") as zip_ref:
-        zip_ref.extractall(destination)
-    os.remove(source)
-    logger.info(f"Removed {source}.")
-
-
-def check_files(directory, expected_count=None):
-    """
-    Returns True if the number of files in the directory matches expected_count, False otherwise.
-    """
-    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
-    return len(files) == expected_count if expected_count is not None else bool(files)
-
-
-@main(config_path="../config/data", config_name="download", version_base=None)
-def prepare_dataset(cfg):
-    """
-    Prepares dataset by downloading and unzipping if necessary.
-    """
-    data_dir = cfg.save_path
-    for data_type, settings in cfg.datasets.items():
-        base_url = settings["base_url"]
-        for dataset_type, dataset_args in settings.items():
-            if dataset_type == "base_url":
-                continue  # Skip the base_url entry
-            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
-            url = f"{base_url}{file_name}"
-            local_zip_path = os.path.join(data_dir, file_name)
-            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
-            final_place = os.path.join(extract_to, dataset_type)
-
-            os.makedirs(extract_to, exist_ok=True)
-            if check_files(final_place, dataset_args.get("file_num")):
-                logger.info(f"Dataset {dataset_type} already verified.")
-                continue
-
-            if not os.path.exists(local_zip_path):
-                download_file(url, local_zip_path)
-            unzip_file(local_zip_path, extract_to)
-
-            if not check_files(final_place, dataset_args.get("file_num")):
-                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
-
-
-if __name__ == "__main__":
-    import sys
-
-    sys.path.append("./")
-    from tools.log_helper import custom_logger
-
-    custom_logger()
-    prepare_dataset()
utils/loss.py
DELETED
@@ -1,2 +0,0 @@
-def get_loss_function(*args, **kwargs):
-    raise NotImplementedError
yolo/config/config.py
CHANGED
@@ -2,9 +2,15 @@ from dataclasses import dataclass
 from typing import Dict, List, Union
 
 
+@dataclass
+class AnchorConfig:
+    reg_max: int
+    strides: List[int]
+
+
 @dataclass
 class Model:
-    anchor: List[List[int]]
+    anchor: AnchorConfig
     model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
 
 
@@ -20,6 +26,8 @@ class DataLoaderConfig:
     shuffle: bool
     num_workers: int
    pin_memory: bool
+    image_size: List[int]
+    class_num: int
 
 
 @dataclass
@@ -52,11 +60,19 @@ class EMAConfig:
     decay: float
 
 
+@dataclass
+class MatcherConfig:
+    iou: str
+    topk: int
+    factor: Dict[str, int]
+
+
 @dataclass
 class TrainConfig:
     optimizer: OptimizerConfig
     scheduler: SchedulerConfig
     ema: EMAConfig
+    matcher: MatcherConfig
 
 
 @dataclass
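The new `AnchorConfig` and `MatcherConfig` fields are filled in by Hydra from the YAML in the next two diffs. As a minimal sketch (hypothetical usage; it assumes the caller sits at the repository root so the relative `config_path` resolves, mirroring the compose call in the new `test_loss.py`):

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the full config the same way tests/test_utils/test_loss.py does.
with initialize(config_path="yolo/config", version_base=None):
    cfg = compose(config_name="config")

# The matcher block comes from yolo/config/hyper/default.yaml (next diff):
print(OmegaConf.to_yaml(cfg.hyper.train.matcher))
# iou: CIoU
# topk: 10
# factor:
#   iou: 6.0
#   cls: 0.5
```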
yolo/config/hyper/default.yaml
CHANGED
@@ -3,12 +3,28 @@ data:
   shuffle: True
   num_workers: 4
   pin_memory: True
+  class_num: 80
+  image_size: [640, 640]
 train:
   optimizer:
     type: Adam
     args:
       lr: 0.001
       weight_decay: 0.0001
+  loss:
+    BCELoss:
+      args:
+    BoxLoss:
+      args:
+        alpha: 0.1
+    DFLoss:
+      args:
+  matcher:
+    iou: CIoU
+    topk: 10
+    factor:
+      iou: 6.0
+      cls: 0.5
   scheduler:
     type: StepLR
     args:
yolo/config/model/v7-base.yaml
CHANGED
@@ -1,5 +1,9 @@
 nc: 80
 
+anchor:
+  reg_max: 16
+  strides: [8, 16, 32]
+
 model:
   backbone:
     - Conv:
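A quick sanity check on the new `anchor` block (a hypothetical snippet; values taken from this YAML and `hyper/default.yaml`): with the default 640×640 `image_size`, the three strides yield 6400 + 1600 + 400 = 8400 anchor points, which is the anchor count the new `bbox_helper.py` below assumes.

```python
image_size = [640, 640]  # from hyper/default.yaml
strides = [8, 16, 32]    # from the anchor block above

counts = [(image_size[0] // s) * (image_size[1] // s) for s in strides]
print(counts, sum(counts))  # [6400, 1600, 400] 8400
```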
yolo/tools/bbox_helper.py
ADDED
@@ -0,0 +1,251 @@
+import math
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from yolo.config.config import MatcherConfig
+
+
+def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
+    metrics = metrics.lower()
+    EPS = 1e-9
+    dtype = bbox1.dtype
+    bbox1 = bbox1.to(torch.float32)
+    bbox2 = bbox2.to(torch.float32)
+
+    # Expand dimensions if necessary
+    if bbox1.ndim == 2 and bbox2.ndim == 2:
+        bbox1 = bbox1.unsqueeze(1)  # (Ax4) -> (Ax1x4)
+        bbox2 = bbox2.unsqueeze(0)  # (Bx4) -> (1xBx4)
+    elif bbox1.ndim == 3 and bbox2.ndim == 3:
+        bbox1 = bbox1.unsqueeze(2)  # (BZxAx4) -> (BZxAx1x4)
+        bbox2 = bbox2.unsqueeze(1)  # (BZxBx4) -> (BZx1xBx4)
+
+    # Calculate intersection coordinates
+    xmin_inter = torch.max(bbox1[..., 0], bbox2[..., 0])
+    ymin_inter = torch.max(bbox1[..., 1], bbox2[..., 1])
+    xmax_inter = torch.min(bbox1[..., 2], bbox2[..., 2])
+    ymax_inter = torch.min(bbox1[..., 3], bbox2[..., 3])
+
+    # Calculate intersection area
+    intersection_area = torch.clamp(xmax_inter - xmin_inter, min=0) * torch.clamp(ymax_inter - ymin_inter, min=0)
+
+    # Calculate area of each bbox
+    area_bbox1 = (bbox1[..., 2] - bbox1[..., 0]) * (bbox1[..., 3] - bbox1[..., 1])
+    area_bbox2 = (bbox2[..., 2] - bbox2[..., 0]) * (bbox2[..., 3] - bbox2[..., 1])
+
+    # Calculate union area
+    union_area = area_bbox1 + area_bbox2 - intersection_area
+
+    # Calculate IoU
+    iou = intersection_area / (union_area + EPS)
+    if metrics == "iou":
+        return iou
+
+    # Calculate centroid distance
+    cx1 = (bbox1[..., 2] + bbox1[..., 0]) / 2
+    cy1 = (bbox1[..., 3] + bbox1[..., 1]) / 2
+    cx2 = (bbox2[..., 2] + bbox2[..., 0]) / 2
+    cy2 = (bbox2[..., 3] + bbox2[..., 1]) / 2
+    cent_dis = (cx1 - cx2) ** 2 + (cy1 - cy2) ** 2
+
+    # Calculate diagonal length of the smallest enclosing box
+    c_x = torch.max(bbox1[..., 2], bbox2[..., 2]) - torch.min(bbox1[..., 0], bbox2[..., 0])
+    c_y = torch.max(bbox1[..., 3], bbox2[..., 3]) - torch.min(bbox1[..., 1], bbox2[..., 1])
+    diag_dis = c_x**2 + c_y**2 + EPS
+
+    diou = iou - (cent_dis / diag_dis)
+    if metrics == "diou":
+        return diou
+
+    # Compute aspect ratio penalty term
+    arctan = torch.atan((bbox1[..., 2] - bbox1[..., 0]) / (bbox1[..., 3] - bbox1[..., 1] + EPS)) - torch.atan(
+        (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
+    )
+    v = (4 / (math.pi**2)) * (arctan**2)
+    alpha = v / (v - iou + 1 + EPS)
+    # Compute CIoU
+    ciou = diou - alpha * v
+    return ciou.to(dtype)
+
+
+def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
+    data_type = bbox.dtype
+    in_type, out_type = indicator.replace(" ", "").split("->")
+
+    if in_type not in ["xyxy", "xywh", "xycwh"] or out_type not in ["xyxy", "xywh", "xycwh"]:
+        raise ValueError("Invalid input or output format")
+
+    if in_type == "xywh":
+        x_min = bbox[..., 0]
+        y_min = bbox[..., 1]
+        x_max = bbox[..., 0] + bbox[..., 2]
+        y_max = bbox[..., 1] + bbox[..., 3]
+    elif in_type == "xyxy":
+        x_min = bbox[..., 0]
+        y_min = bbox[..., 1]
+        x_max = bbox[..., 2]
+        y_max = bbox[..., 3]
+    elif in_type == "xycwh":
+        x_min = bbox[..., 0] - bbox[..., 2] / 2
+        y_min = bbox[..., 1] - bbox[..., 3] / 2
+        x_max = bbox[..., 0] + bbox[..., 2] / 2
+        y_max = bbox[..., 1] + bbox[..., 3] / 2
+
+    if out_type == "xywh":
+        bbox = torch.stack([x_min, y_min, x_max - x_min, y_max - y_min], dim=-1)
+    elif out_type == "xyxy":
+        bbox = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
+    elif out_type == "xycwh":
+        bbox = torch.stack([(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min], dim=-1)
+
+    return bbox.to(dtype=data_type)
+
+
+def make_anchor(image_size: List[int], strides: List[int], device):
+    W, H = image_size
+    anchors = []
+    scaler = []
+    for stride in strides:
+        anchor_num = W // stride * H // stride
+        scaler.append(torch.full((anchor_num,), stride, device=device))
+        shift = stride // 2
+        x = torch.arange(0, W, stride, device=device) + shift
+        y = torch.arange(0, H, stride, device=device) + shift
+        anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
+        anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
+        anchors.append(anchor)
+    all_anchors = torch.cat(anchors, dim=0)
+    all_scalers = torch.cat(scaler, dim=0)
+    return all_anchors, all_scalers
+
+
+class BoxMatcher:
+    def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
+        self.class_num = class_num
+        self.anchors = anchors
+        for attr_name in cfg:
+            setattr(self, attr_name, cfg[attr_name])
+
+    def get_valid_matrix(self, target_bbox: Tensor):
+        """
+        Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
+
+        Args:
+            target_bbox [batch x targets x 4]: The bounding box of each target.
+        Returns:
+            [batch x targets x anchors]: A boolean tensor indicating if the target bounding box overlaps with the anchor.
+        """
+        Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
+        anchors = self.anchors[None, None]  # add an axis at the first and second dimension
+        anchors_x, anchors_y = anchors.unbind(dim=3)
+        target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
+        target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
+        target_on_anchor = target_in_x & target_in_y
+        return target_on_anchor
+
+    def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
+        """
+        Get the predicted class probabilities corresponding to the target classes across all anchors.
+
+        Args:
+            predict_cls [batch x class x anchors]: The predicted probabilities for each class across each anchor.
+            target_cls [batch x targets]: The class index for each target.
+
+        Returns:
+            [batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
+        """
+        target_cls = target_cls.expand(-1, -1, 8400)
+        predict_cls = predict_cls.transpose(1, 2)
+        cls_probabilities = torch.gather(predict_cls, 1, target_cls)
+        return cls_probabilities
+
+    def get_iou_matrix(self, predict_bbox, target_bbox) -> Tensor:
+        """
+        Get the IoU between each target bounding box and each predicted bounding box.
+
+        Args:
+            predict_bbox [batch x predicts x 4]: Bounding box with [x1, y1, x2, y2].
+            target_bbox [batch x targets x 4]: Bounding box with [x1, y1, x2, y2].
+        Returns:
+            [batch x targets x predicts]: The IoU scores between each target and prediction.
+        """
+        return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
+
+    def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
+        """
+        Filter the top-k suitability of targets for each anchor.
+
+        Args:
+            target_matrix [batch x targets x anchors]: The suitability for each target-anchor pair.
+            topk (int, optional): Number of top scores to retain per anchor.
+
+        Returns:
+            topk_targets [batch x targets x anchors]: Only keeps the top-k targets for each anchor.
+            topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
+        """
+        values, indices = target_matrix.topk(topk, dim=-1)
+        topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
+        topk_targets.scatter_(dim=-1, index=indices, src=values)
+        topk_masks = topk_targets > 0
+        return topk_targets, topk_masks
+
+    def filter_duplicates(self, target_matrix: Tensor):
+        """
+        Filter the maximum-suitability target index for each anchor.
+
+        Args:
+            target_matrix [batch x targets x anchors]: The suitability for each target-anchor pair.
+
+        Returns:
+            unique_indices [batch x anchors x 1]: The index of the best target for each anchor.
+        """
+        unique_indices = target_matrix.argmax(dim=1)
+        return unique_indices[..., None]
+
+    def __call__(self, target: Tensor, predict: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        1. For each anchor prediction, find the highest-suitability targets.
+        2. Select the targets.
+        3. Normalize the class probabilities of the targets.
+        """
+        predict_cls, predict_bbox = predict.split(self.class_num, dim=-1)  # B x HW x (C B) -> B x HW x C, B x HW x B
+        target_cls, target_bbox = target.split([1, 4], dim=-1)  # B x N x (C B) -> B x N x C, B x N x B
+        target_cls = target_cls.long()
+
+        # get valid matrix (which anchor grid each gt appears in)
+        grid_mask = self.get_valid_matrix(target_bbox)
+
+        # get iou matrix (iou of each gt bbox with each predicted anchor)
+        iou_mat = self.get_iou_matrix(predict_bbox, target_bbox)
+
+        # get cls matrix (cls prob of each gt class against each predicted class)
+        cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
+
+        # TODO: alpha and beta should be set at hydra
+        target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
+
+        # choose topk
+        # TODO: topk should be set at hydra
+        topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
+
+        # drop anchor predictions assigned to multiple gts
+        unique_indices = self.filter_duplicates(topk_targets)
+
+        # TODO: do we need grid_mask? Filter the valid ground truth
+        valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
+
+        align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
+        align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
+        align_cls = F.one_hot(align_cls, self.class_num)
+
+        # normalize class distribution
+        max_target = target_matrix.amax(dim=-1, keepdim=True)
+        max_iou = iou_mat.amax(dim=-1, keepdim=True)
+        normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
+        normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
+        align_cls = align_cls * normalize_term * valid_mask[:, :, None]
+
+        return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
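A small self-contained sketch of the two standalone helpers above (box values are made up for illustration):

```python
import torch

from yolo.tools.bbox_helper import calculate_iou, transform_bbox

# Pairwise IoU: (A, 4) vs (B, 4) boxes in xyxy format broadcast to (A, B).
boxes_a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
boxes_b = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
print(calculate_iou(boxes_a, boxes_b, metrics="iou"))   # 25 / 175 ≈ 0.1429
print(calculate_iou(boxes_a, boxes_b, metrics="ciou"))  # adds center/aspect penalties

# Center-format to corner-format conversion.
print(transform_bbox(torch.tensor([[5.0, 5.0, 10.0, 10.0]]), "xycwh -> xyxy"))
# tensor([[ 0.,  0., 10., 10.]])
```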
yolo/utils/loss.py
CHANGED
@@ -1,2 +1,166 @@
+import time
+from typing import Any, List, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from loguru import logger
+from torch import Tensor, nn
+from torch.nn import BCEWithLogitsLoss
+
+from yolo.config.config import Config
+from yolo.tools.bbox_helper import (
+    BoxMatcher,
+    calculate_iou,
+    make_anchor,
+    transform_bbox,
+)
+
+
 def get_loss_function(*args, **kwargs):
     raise NotImplementedError
+
+
+class BCELoss(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=device), reduction="none")
+
+    def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
+        return self.bce(predicts_cls, targets_cls).sum() / cls_norm
+
+
+class BoxLoss(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, predicts_bbox: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
+    ) -> Any:
+        valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
+        picked_predict = predicts_bbox[valid_bbox].view(-1, 4)
+        picked_targets = targets_bbox[valid_bbox].view(-1, 4)
+
+        iou = calculate_iou(picked_predict, picked_targets, "ciou").diag()
+        loss_iou = 1.0 - iou
+        loss_iou = (loss_iou * box_norm).sum() / cls_norm
+        return loss_iou
+
+
+class DFLoss(nn.Module):
+    def __init__(self, anchors: Tensor, scaler: Tensor, reg_max: int) -> None:
+        super().__init__()
+        self.anchors = anchors
+        self.scaler = scaler
+        self.reg_max = reg_max
+
+    def forward(
+        self, predicts_anc: Tensor, targets_bbox: Tensor, valid_masks: Tensor, box_norm: Tensor, cls_norm: Tensor
+    ) -> Any:
+        valid_bbox = valid_masks[..., None].expand(-1, -1, 4)
+        bbox_lt, bbox_rb = targets_bbox.chunk(2, -1)
+        anchors_norm = (self.anchors / self.scaler[:, None])[None]
+        targets_dist = torch.cat(((anchors_norm - bbox_lt), (bbox_rb - anchors_norm)), -1).clamp(0, self.reg_max - 1.01)
+        picked_targets = targets_dist[valid_bbox].view(-1)
+        picked_predict = predicts_anc[valid_bbox].view(-1, self.reg_max)
+
+        label_left, label_right = picked_targets.floor(), picked_targets.floor() + 1
+        weight_left, weight_right = label_right - picked_targets, picked_targets - label_left
+
+        loss_left = F.cross_entropy(picked_predict, label_left.to(torch.long), reduction="none")
+        loss_right = F.cross_entropy(picked_predict, label_right.to(torch.long), reduction="none")
+        loss_dfl = loss_left * weight_left + loss_right * weight_right
+        loss_dfl = loss_dfl.view(-1, 4).mean(-1)
+        loss_dfl = (loss_dfl * box_norm).sum() / cls_norm
+        return loss_dfl
+
+
+class YOLOLoss:
+    def __init__(self, cfg: Config) -> None:
+        self.reg_max = cfg.model.anchor.reg_max
+        self.class_num = cfg.hyper.data.class_num
+        self.image_size = list(cfg.hyper.data.image_size)
+        self.strides = cfg.model.anchor.strides
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
+        self.scale_up = torch.tensor(self.image_size * 2, device=device)
+
+        self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
+
+        self.cls = BCELoss()
+        self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
+        self.iou = BoxLoss()
+
+        self.matcher = BoxMatcher(cfg.hyper.train.matcher, self.class_num, self.anchors)
+
+    def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
+        """
+        args:
+            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
+        return:
+            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBbox = 80 + 4 (xyXY)
+        """
+        preds = []
+        for pred in predicts:
+            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
+        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
+
+        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
+
+        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
+
+        lt, rb = pred_LTRB.chunk(2, dim=-1)
+        pred_minXY = self.anchors - lt
+        pred_maxXY = self.anchors + rb
+        predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
+
+        return predicts, preds_anc
+
+    def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
+        """
+        return List:
+        """
+        targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
+        bbox_num = targets[:, 0].int().bincount()
+        batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
+        for instance_idx, bbox_num in enumerate(bbox_num):
+            instance_targets = targets[targets[:, 0] == instance_idx]
+            batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
+        return batch_targets
+
+    def separate_anchor(self, anchors):
+        """
+        Separate anchor and bounding box.
+        """
+        anchors_cls, anchors_box = torch.split(anchors, (self.class_num, 4), dim=-1)
+        anchors_box = anchors_box / self.scaler[None, :, None]
+        return anchors_cls, anchors_box
+
+    @torch.autocast("cuda" if torch.cuda.is_available() else "cpu")
+    def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        # Batch_Size x (Anchor + Class) x H x W
+        # TODO: check datatype, why targets deviate slightly from the original version
+        predicts, predicts_anc = self.parse_predicts(predicts[0])
+        targets = self.parse_targets(targets, batch_size=predicts.size(0))
+
+        align_targets, valid_masks = self.matcher(targets, predicts)
+        # calculate loss between each instance and its prediction
+
+        targets_cls, targets_bbox = self.separate_anchor(align_targets)
+        predicts_cls, predicts_bbox = self.separate_anchor(predicts)
+
+        cls_norm = targets_cls.sum()
+        box_norm = targets_cls.sum(-1)[valid_masks]
+
+        ## -- CLS -- ##
+        loss_cls = self.cls(predicts_cls, targets_cls, cls_norm)
+        ## -- IOU -- ##
+        loss_iou = self.iou(predicts_bbox, targets_bbox, valid_masks, box_norm, cls_norm)
+        ## -- DFL -- ##
+        loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
+
+        logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
+        return loss_iou, loss_dfl, loss_cls
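To illustrate the distance decoding inside `parse_predicts` (a sketch with made-up logits; `reg_max: 16` comes from `v7-base.yaml`): each of the four box sides is predicted as a 16-bin distribution, and the softmax-weighted dot product with `arange(reg_max)` gives the expected distance in stride units, which is then scaled by that anchor's stride.

```python
import torch

reg_max = 16
reverse_reg = torch.arange(reg_max, dtype=torch.float32)

# Distribution for one side of one anchor: all mass split between bins 3 and 4.
logits = torch.full((reg_max,), -1e9)
logits[3] = logits[4] = 0.0

expected = logits.softmax(dim=-1) @ reverse_reg  # expected bin index: 3.5
print(expected.item(), (expected * 8).item())    # 3.5, 28.0 px at stride 8
```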