Merge branch 'Lightning'

Files changed:
- tests/test_utils/test_bounding_box_utils.py +52 -11
- yolo/__init__.py +10 -7
- yolo/config/general.yaml +1 -1
- yolo/config/task/inference.yaml +1 -1
- yolo/config/task/validation.yaml +2 -2
- yolo/lazy.py +26 -26
- yolo/model/module.py +1 -1
- yolo/model/yolo.py +7 -7
- yolo/tools/data_loader.py +52 -57
- yolo/tools/dataset_preparation.py +3 -3
- yolo/tools/drawer.py +3 -3
- yolo/tools/loss_functions.py +9 -7
- yolo/tools/solver.py +134 -247
- yolo/utils/bounding_box_utils.py +27 -15
- yolo/utils/dataset_utils.py +15 -1
- yolo/utils/deploy_utils.py +7 -7
- yolo/utils/logger.py +11 -0
- yolo/utils/logging_utils.py +222 -206
- yolo/utils/model_utils.py +9 -4
- yolo/utils/solver_utils.py +3 -2
tests/test_utils/test_bounding_box_utils.py
CHANGED
@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):


 def test_bbox_nms():
-    cls_dist = tensor(
+    cls_dist = torch.tensor(
         [
+            [
+                [0.7, 0.1, 0.2],  # High confidence, class 0
+                [0.3, 0.6, 0.1],  # High confidence, class 1
+                [-3.0, -2.0, -1.0],  # low confidence, class 2
+                [0.6, 0.2, 0.2],  # Medium confidence, class 0
+            ],
+            [
+                [0.55, 0.25, 0.2],  # Medium confidence, class 0
+                [-4.0, -0.5, -2.0],  # low confidence, class 1
+                [0.15, 0.2, 0.65],  # Medium confidence, class 2
+                [0.8, 0.1, 0.1],  # High confidence, class 0
+            ],
+        ],
         dtype=float32,
     )
+
+    bbox = torch.tensor(
+        [
+            [
+                [0, 0, 160, 120],  # Overlaps with box 4
+                [160, 120, 320, 240],
+                [0, 120, 160, 240],
+                [16, 12, 176, 132],
+            ],
+            [
+                [0, 0, 160, 120],  # Overlaps with box 4
+                [160, 120, 320, 240],
+                [0, 120, 160, 240],
+                [16, 12, 176, 132],
+            ],
+        ],
+        dtype=float32,
+    )
+
     nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)

+    # Batch 1:
+    # - box 1 is kept with class 0 as it has a higher confidence than box 4, i.e. box 4 is filtered out
+    # - box 2 is kept with class 1
+    # - box 3 is rejected by the confidence filter
+    # Batch 2:
+    # - box 4 is kept with class 0 as it has a higher confidence than box 1, i.e. box 1 is filtered out
+    # - box 2 is rejected by the confidence filter
+    # - box 3 is kept with class 2
+    expected_output = torch.tensor(
+        [
+            [
+                [0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
+                [1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
+            ],
+            [
+                [0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
+                [2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
+            ],
+        ]
+    )

     output = bbox_nms(cls_dist, bbox, nms_cfg)
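The expected confidence values line up with a sigmoid over the winning class logits, which is presumably what bbox_nms applies to cls_dist before thresholding — a quick check in plain PyTorch, no repo code needed:

    import torch

    # winning logits of the four kept boxes: 0.7, 0.6 (batch 1) and 0.8, 0.65 (batch 2)
    logits = torch.tensor([0.7, 0.6, 0.8, 0.65])
    print(torch.sigmoid(logits))  # tensor([0.6682, 0.6457, 0.6900, 0.6570])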
yolo/__init__.py
CHANGED
@@ -2,18 +2,22 @@ from yolo.config.config import Config, NMSConfig
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
+from yolo.tools.solver import TrainModel
 from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
 from yolo.utils.deploy_utils import FastModelLoader
+from yolo.utils.logging_utils import (
+    ImageLogger,
+    YOLORichModelSummary,
+    YOLORichProgressBar,
+)
 from yolo.utils.model_utils import PostProccess

 __all__ = [
     "create_model",
     "Config",
+    "YOLORichProgressBar",
     "NMSConfig",
+    "YOLORichModelSummary",
     "validate_log_directory",
     "draw_bboxes",
     "Vec2Box",
@@ -21,10 +25,9 @@ __all__ = [
     "bbox_nms",
     "create_converter",
     "AugmentationComposer",
+    "ImageLogger",
     "create_dataloader",
     "FastModelLoader",
-    "ModelTrainer",
-    "ModelValidator",
+    "TrainModel",
     "PostProccess",
 ]
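With these exports in place, downstream code can pull the Lightning entry point and the Rich callbacks straight from the package — a minimal sketch:

    from yolo import Config, TrainModel, YOLORichModelSummary, YOLORichProgressBar, create_model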
yolo/config/general.yaml
CHANGED
@@ -7,7 +7,7 @@ out_path: runs
 exist_ok: True

 lucky_number: 10
-use_wandb: False
+use_wandb: True
 use_tensorboard: False

 weight: True # Path to weight or True for auto, False for no pretrained weight
yolo/config/task/inference.yaml
CHANGED
@@ -8,4 +8,4 @@ data:
 nms:
   min_confidence: 0.5
   min_iou: 0.5
+save_predict: True
yolo/config/task/validation.yaml
CHANGED
@@ -8,5 +8,5 @@ data:
 pin_memory: True
 data_augment: {}
 nms:
-  min_confidence: 0.
-  min_iou: 0.
+  min_confidence: 0.0001
+  min_iou: 0.7
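The asymmetry with the inference thresholds above is deliberate: validation runs NMS with a near-zero confidence floor so that mAP is computed over the full precision-recall curve, while the looser 0.7 IoU cut keeps more candidate boxes for the metric to match against.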
yolo/lazy.py
CHANGED
@@ -2,41 +2,41 @@ import sys
 from pathlib import Path

 import hydra
+from lightning import Trainer

 project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))

 from yolo.config.config import Config
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import create_converter
-from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import ProgressLogger
-from yolo.utils.model_utils import get_device
+from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
+from yolo.utils.logging_utils import setup


 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
+    callbacks, loggers = setup(cfg)
+
+    trainer = Trainer(
+        accelerator="cuda",
+        max_epochs=getattr(cfg.task, "epoch", None),
+        precision="16-mixed",
+        callbacks=callbacks,
+        logger=loggers,
+        log_every_n_steps=1,
+        gradient_clip_val=10,
+        deterministic=True,
+    )
+
+    match cfg.task.task:
+        case "train":
+            model = TrainModel(cfg)
+            trainer.fit(model)
+        case "validation":
+            model = ValidateModel(cfg)
+            trainer.validate(model)
+        case "inference":
+            model = InferenceModel(cfg)
+            trainer.predict(model)


 if __name__ == "__main__":
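In effect the entry point now delegates device placement, mixed precision, gradient clipping, and logging to Lightning's Trainer instead of the old hand-rolled loops; assuming standard Hydra override syntax, a run would look like `python yolo/lazy.py task=train` or `python yolo/lazy.py task=validation`.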
yolo/model/module.py
CHANGED
@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from loguru import logger
 from torch import Tensor, nn
 from torch.nn.common_types import _size_2_t

+from yolo.utils.logger import logger
 from yolo.utils.module_utils import auto_pad, create_activation_function, round_up
yolo/model/yolo.py
CHANGED
@@ -3,12 +3,12 @@ from pathlib import Path
 from typing import Dict, List, Union

 import torch
-from loguru import logger
 from omegaconf import ListConfig, OmegaConf
 from torch import nn

 from yolo.config.config import ModelConfig, YOLOLayer
 from yolo.tools.dataset_preparation import prepare_weight
+from yolo.utils.logger import logger
 from yolo.utils.module_utils import get_layer_map

@@ -32,10 +32,10 @@ class YOLO(nn.Module):
     def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
         self.layer_index = {}
         output_dim, layer_idx = [3], 1
+        logger.info(f":tractor: Building YOLO")
         for arch_name in model_arch:
             if model_arch[arch_name]:
+                logger.info(f" :building_construction: Building {arch_name}")
                 for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
                     layer_type, layer_info = next(iter(layer_spec.items()))
                     layer_args = layer_info.get("args", {})
@@ -123,7 +123,7 @@ class YOLO(nn.Module):
             weights: A OrderedDict containing the new weights.
         """
         if isinstance(weights, Path):
-            weights = torch.load(weights, map_location=torch.device("cpu"))
+            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
         if "model_state_dict" in weights:
             weights = weights["model_state_dict"]

@@ -144,7 +144,7 @@ class YOLO(nn.Module):
         for error_name, error_set in error_dict.items():
             for weight_name in error_set:
+                logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

         self.model.load_state_dict(model_state_dict)

@@ -171,7 +171,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
         prepare_weight(weight_path=weight_path)
     if weight_path.exists():
         model.save_load_weights(weight_path)
+        logger.info(":white_check_mark: Success load model & weight")
     else:
+        logger.info(":white_check_mark: Success load model")
     return model
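A note on the weights_only=False addition: newer PyTorch releases default torch.load to weights_only=True, which refuses checkpoints that pickle anything beyond plain tensors (such as the optimizer state stored alongside model_state_dict). Passing the flag explicitly keeps such checkpoints loadable — a sketch with a hypothetical checkpoint path:

    import torch

    # "E000.pt" is a hypothetical checkpoint file; weights_only=False allows
    # arbitrary pickled objects (optimizer state, epoch counters) to load.
    checkpoint = torch.load("E000.pt", map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("model_state_dict", checkpoint)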
yolo/tools/data_loader.py
CHANGED
@@ -5,12 +5,10 @@ from typing import Generator, List, Tuple, Union

 import numpy as np
 import torch
-from loguru import logger
 from PIL import Image
 from rich.progress import track
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler

 from yolo.config.config import DataConfig, DatasetConfig
 from yolo.tools.data_augmentation import *
@@ -20,7 +18,9 @@ from yolo.utils.dataset_utils import (
     create_image_metadata,
     locate_label_paths,
     scale_segmentation,
+    tensorlize,
 )
+from yolo.utils.logger import logger


 class YoloDataset(Dataset):
@@ -32,7 +32,8 @@ class YoloDataset(Dataset):
         transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
         self.transform = AugmentationComposer(transforms, self.image_size)
         self.transform.get_more_data = self.get_more_data
+        img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
+        self.img_paths, self.bboxes = img_paths, bboxes

     def load_data(self, dataset_path: Path, phase_name: str):
         """
@@ -45,15 +46,15 @@ class YoloDataset(Dataset):
         Returns:
             dict: The loaded data from the cache for the specified phase.
         """
+        cache_path = dataset_path / f"{phase_name}.cache1"

         if not cache_path.exists():
+            logger.info(f":factory: Generating {phase_name} cache")
             data = self.filter_data(dataset_path, phase_name)
             torch.save(data, cache_path)
         else:
             data = torch.load(cache_path, weights_only=False)
+            logger.info(f":package: Loaded {phase_name} cache")
         return data

     def filter_data(self, dataset_path: Path, phase_name: str) -> list:
@@ -103,7 +104,7 @@ class YoloDataset(Dataset):
                 img_path = images_path / image_name
                 data.append((img_path, labels))
                 valid_inputs += 1
+        logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
         return data

     def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
@@ -132,9 +133,11 @@ class YoloDataset(Dataset):
             return torch.zeros((0, 5))

     def get_data(self, idx):
+        img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
+        valid_mask = bboxes[:, 0] != -1
+        with Image.open(img_path) as img:
+            img = img.convert("RGB")
+        return img, torch.from_numpy(bboxes[valid_mask]), img_path

     def get_more_data(self, num: int = 1):
         indices = torch.randint(0, len(self), (num,))
@@ -143,67 +146,59 @@ class YoloDataset(Dataset):
     def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
         img, bboxes, img_path = self.get_data(idx)
         img, bboxes, rev_tensor = self.transform(img, bboxes)
+        bboxes[:, [1, 3]] *= self.image_size[0]
+        bboxes[:, [2, 4]] *= self.image_size[1]
         return img, bboxes, rev_tensor, img_path

     def __len__(self) -> int:
+        return len(self.bboxes)


-class YoloDataLoader(DataLoader):
-    def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
-        """Initializes the YoloDataLoader with hydra-config files."""
-        dataset = YoloDataset(data_cfg, dataset_cfg, task)
-        sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
-        self.image_size = data_cfg.image_size[0]
-        super().__init__(
-            dataset,
-            batch_size=data_cfg.batch_size,
-            sampler=sampler,
-            shuffle=data_cfg.shuffle and not use_ddp,
-            num_workers=data_cfg.cpu_num,
-            pin_memory=data_cfg.pin_memory,
-            collate_fn=self.collate_fn,
-        )
-
-    def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
-        ...
-        batch_targets[:, :, 1:] *= self.image_size
+def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
+    """
+    A collate function to handle batching of images and their corresponding targets.
+
+    Args:
+        batch (list of tuples): Each tuple contains:
+            - image (Tensor): The image tensor.
+            - labels (Tensor): The tensor of labels for the image.
+
+    Returns:
+        Tuple[Tensor, List[Tensor]]: A tuple containing:
+            - A tensor of batched images.
+            - A list of tensors, each corresponding to bboxes for each image in the batch.
+    """
+    batch_size = len(batch)
+    target_sizes = [item[1].size(0) for item in batch]
+    # TODO: Improve readability of these proccess
+    # TODO: remove maxBbox or reduce loss function memory usage
+    batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
+    batch_targets[:, :, 0] = -1
+    for idx, target_size in enumerate(target_sizes):
+        batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
+
+    batch_images, _, batch_reverse, batch_path = zip(*batch)
+    batch_images = torch.stack(batch_images)
+    batch_reverse = torch.stack(batch_reverse)
+
+    return batch_size, batch_images, batch_targets, batch_reverse, batch_path


 def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
     if task == "inference":
         return StreamDataLoader(data_cfg)

     if dataset_cfg.auto_download:
         prepare_dataset(dataset_cfg, task)
+    dataset = YoloDataset(data_cfg, dataset_cfg, task)
+
+    return DataLoader(
+        dataset,
+        batch_size=data_cfg.batch_size,
+        num_workers=data_cfg.cpu_num,
+        pin_memory=data_cfg.pin_memory,
+        collate_fn=collate_fn,
+    )


 class StreamDataLoader:
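Since collate_fn is now module-level, it can be handed to a stock DataLoader and tested in isolation. A minimal sketch of its contract, with hypothetical tensors shaped like YoloDataset.__getitem__ output (image, scaled bboxes [class, x1, y1, x2, y2], rev_tensor, path):

    import torch

    from yolo.tools.data_loader import collate_fn

    batch = [
        (torch.zeros(3, 640, 640), torch.tensor([[0.0, 10, 10, 50, 50], [1.0, 5, 5, 20, 20]]), torch.zeros(5), "a.jpg"),
        (torch.zeros(3, 640, 640), torch.tensor([[2.0, 1, 2, 3, 4]]), torch.zeros(5), "b.jpg"),
    ]
    batch_size, images, targets, rev, paths = collate_fn(batch)
    print(images.shape)   # torch.Size([2, 3, 640, 640])
    print(targets.shape)  # torch.Size([2, 2, 5]); targets[1, 1, 0] == -1 marks the padded slot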
yolo/tools/dataset_preparation.py
CHANGED
@@ -3,10 +3,10 @@ from pathlib import Path
 from typing import Optional

 import requests
-from loguru import logger
 from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

 from yolo.config.config import DatasetConfig
+from yolo.utils.logger import logger


 def download_file(url, destination: Path):
@@ -30,7 +30,7 @@ def download_file(url, destination: Path):
     for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
         file.write(data)
         progress.update(task, advance=len(data))
+    logger.info(":white_check_mark: Download completed.")


 def unzip_file(source: Path, destination: Path):
@@ -71,7 +71,7 @@ def prepare_dataset(dataset_cfg: DatasetConfig, task: str):

     final_place.mkdir(parents=True, exist_ok=True)
     if check_files(final_place, dataset_args.get("file_num")):
+        logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.")
         continue

     if not local_zip_path.exists():
yolo/tools/drawer.py
CHANGED
@@ -3,12 +3,12 @@ from typing import List, Optional, Union

 import numpy as np
 import torch
-from loguru import logger
 from PIL import Image, ImageDraw, ImageFont
 from torchvision.transforms.functional import to_pil_image

 from yolo.config.config import ModelConfig
 from yolo.model.yolo import YOLO
+from yolo.utils.logger import logger


 def draw_bboxes(
@@ -121,6 +121,6 @@ def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=False):
             dot.edge(str(idx), str(jdx))
     try:
         dot.render("Model-arch", format="png", cleanup=True)
+        logger.info(":artist_palette: Drawing Model Architecture at Model-arch.png")
     except:
+        logger.warning(":warning: Could not find graphviz backend, continue without drawing the model architecture")
yolo/tools/loss_functions.py
CHANGED
@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Tuple

 import torch
 import torch.nn.functional as F
-from loguru import logger
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss

 from yolo.config.config import Config, LossConfig
 from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
+from yolo.utils.logger import logger


 class BCELoss(nn.Module):
@@ -124,17 +124,19 @@ class DualLoss:
         aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
         main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)

+        total_loss = [
+            self.iou_rate * (aux_iou * self.aux_rate + main_iou),
+            self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
+            self.cls_rate * (aux_cls * self.aux_rate + main_cls),
+        ]
         loss_dict = {
-            "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
-            "DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
-            "BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
+            f"Loss/{name}Loss": value.detach().item() for name, value in zip(["Box", "DFL", "BCE"], total_loss)
         }
-        return loss_sum, loss_dict
+        return sum(total_loss), loss_dict


 def create_loss_function(cfg: Config, vec2box) -> DualLoss:
     # TODO: make it flexible, if cfg doesn't contain aux, only use SingleLoss
     loss_function = DualLoss(cfg, vec2box)
+    logger.info(":white_check_mark: Success load loss function")
     return loss_function
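The refactor moves the weighted sum into the return value and turns the dict into plain floats for logging; the comprehension yields the keys Loss/BoxLoss, Loss/DFLLoss, and Loss/BCELoss. A standalone check of just that dictionary logic:

    import torch

    total_loss = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]
    loss_dict = {f"Loss/{name}Loss": v.detach().item() for name, v in zip(["Box", "DFL", "BCE"], total_loss)}
    print(loss_dict)        # {'Loss/BoxLoss': 1.0, 'Loss/DFLLoss': 2.0, 'Loss/BCELoss': 3.0}
    print(sum(total_loss))  # tensor(6.)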
yolo/tools/solver.py
CHANGED
@@ -1,267 +1,154 @@
-import contextlib
-import io
-import json
-import os
 import time
-from collections import defaultdict
 from pathlib import Path
-from typing import Dict, Optional

+import cv2
+import numpy as np
+from lightning import LightningModule
+from torchmetrics.detection import MeanAveragePrecision
-from torch.cuda.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader

 from yolo.config.config import Config
+from yolo.model.yolo import create_model
+from yolo.tools.data_loader import create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import create_loss_function
+from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
-from yolo.utils.logging_utils import ProgressLogger, log_model_structure
-from yolo.utils.model_utils import (
-    ExponentialMovingAverage,
-    PostProccess,
-    collect_prediction,
-    create_optimizer,
-    create_scheduler,
-    predicts_to_json,
-)
-from yolo.utils.solver_utils import calculate_ap
+from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler


-class ModelTrainer:    # removed wholesale: hand-rolled training loop (GradScaler/autocast, DDP all_reduce, EMA, checkpointing)
-class ModelTester:     # removed wholesale: manual inference loop over StreamDataLoader
-class ModelValidator:  # removed wholesale: manual validation loop + pycocotools mAP via calculate_ap


+class BaseModel(LightningModule):
+    def __init__(self, cfg: Config):
+        super().__init__()
+        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class ValidateModel(BaseModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        if self.cfg.task.task == "validation":
+            self.validation_cfg = self.cfg.task
+        else:
+            self.validation_cfg = self.cfg.task.validation
+        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
+        self.metric.warn_on_many_detections = False
+        self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+
+    def setup(self, stage):
+        self.vec2box = create_converter(
+            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+        )
+        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
+
+    def val_dataloader(self):
+        return self.val_loader
+
+    def validation_step(self, batch, batch_idx):
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        predicts = self.post_proccess(self(images))
+        batch_metrics = self.metric(
+            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
+        )
+
+        self.log_dict(
+            {
+                "map": batch_metrics["map"],
+                "map_50": batch_metrics["map_50"],
+            },
+            on_step=True,
+            batch_size=batch_size,
+        )
+        return predicts
+
+    def on_validation_epoch_end(self):
+        epoch_metrics = self.metric.compute()
+        del epoch_metrics["classes"]
+        self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
+        self.log_dict(
+            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, rank_zero_only=True
+        )
+        self.metric.reset()
+
+
+class TrainModel(ValidateModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
+
+    def setup(self, stage):
+        super().setup(stage)
+        self.loss_fn = create_loss_function(self.cfg, self.vec2box)
+
+    def train_dataloader(self):
+        return self.train_loader
+
+    def on_train_epoch_start(self):
+        self.trainer.optimizers[0].next_epoch(len(self.train_loader))
+
+    def training_step(self, batch, batch_idx):
+        lr_dict = self.trainer.optimizers[0].next_batch()
+        batch_size, images, targets, *_ = batch
+        predicts = self(images)
+        aux_predicts = self.vec2box(predicts["AUX"])
+        main_predicts = self.vec2box(predicts["Main"])
+        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
+        self.log_dict(
+            loss_item,
+            prog_bar=True,
+            on_epoch=True,
+            batch_size=batch_size,
+            rank_zero_only=True,
+        )
+        self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
+        return loss * batch_size
+
+    def configure_optimizers(self):
+        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
+        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
+        return [optimizer], [scheduler]
+
+
+class InferenceModel(BaseModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        # TODO: Add FastModel
+        self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
+
+    def setup(self, stage):
+        self.vec2box = create_converter(
+            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+        )
+        self.post_process = PostProccess(self.vec2box, self.cfg.task.nms)
+
+    def predict_dataloader(self):
+        return self.predict_loader
+
+    def predict_step(self, batch, batch_idx):
+        images, rev_tensor, origin_frame = batch
+        predicts = self.post_process(self(images), rev_tensor)
+        img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
+        if getattr(self.predict_loader, "is_stream", None):
+            fps = self._display_stream(img)
+        else:
+            fps = None
+        if getattr(self.cfg.task, "save_predict", None):
+            self._save_image(img, batch_idx)
+        return img, fps
+
+    def _display_stream(self, img):
+        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        fps = 1 / (time.time() - self.trainer.current_epoch_start_time)
+        cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
+        cv2.imshow("Prediction", img)
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            self.trainer.should_stop = True
+        return fps
+
+    def _save_image(self, img, batch_idx):
+        save_image_path = Path(self.logger.save_dir) / f"frame{batch_idx:03d}.png"
+        img.save(save_image_path)
+        print(f"💾 Saved visualize image at {save_image_path}")
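ValidateModel replaces the old pycocotools round-trip with torchmetrics; the update/compute cycle it relies on can be exercised in isolation:

    import torch
    from torchmetrics.detection import MeanAveragePrecision

    metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
    preds = [{"boxes": torch.tensor([[0.0, 0.0, 160.0, 120.0]]), "scores": torch.tensor([0.67]), "labels": torch.tensor([0])}]
    target = [{"boxes": torch.tensor([[0.0, 0.0, 160.0, 120.0]]), "labels": torch.tensor([0])}]
    metric.update(preds, target)
    print(metric.compute()["map_50"])  # tensor(1.) for this perfect match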
yolo/utils/bounding_box_utils.py
CHANGED
@@ -4,17 +4,17 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from loguru import logger
 from torch import Tensor, arange, tensor
 from torchvision.ops import batched_nms

 from yolo.config.config import AnchorConfig, MatcherConfig, ModelConfig, NMSConfig
 from yolo.model.yolo import YOLO
+from yolo.utils.logger import logger


 def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
     metrics = metrics.lower()
+    EPS = 1e-7
     dtype = bbox1.dtype
     bbox1 = bbox1.to(torch.float32)
     bbox2 = bbox2.to(torch.float32)
@@ -69,7 +69,8 @@ def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
             (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
         )
         v = (4 / (math.pi**2)) * (arctan**2)
+        with torch.no_grad():
+            alpha = v / (v - iou + 1 + EPS)
         # Compute CIoU
         ciou = diou - alpha * v
         return ciou.to(dtype)
@@ -129,7 +130,10 @@ def generate_anchors(image_size: List[int], strides: List[int]):
         shift = stride // 2
         h = torch.arange(0, H, stride) + shift
         w = torch.arange(0, W, stride) + shift
+        if torch.__version__ >= "2.3.0":
+            anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
+        else:
+            anchor_h, anchor_w = torch.meshgrid(h, w)
         anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
         anchors.append(anchor)
     all_anchors = torch.cat(anchors, dim=0)
@@ -207,7 +211,7 @@ class BoxMatcher:
         topk_masks = topk_targets > 0
         return topk_targets, topk_masks

-    def filter_duplicates(self, target_matrix: Tensor):
+    def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
         """
         Filter the maximum suitability target index of each anchor.

@@ -217,9 +221,11 @@ class BoxMatcher:
         Returns:
             unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
         """
+        duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
+        max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
+        topk_mask = torch.where(duplicates, max_idx, topk_mask)
+        unique_indices = topk_mask.argmax(dim=1)
+        return unique_indices[..., None], topk_mask.sum(1), topk_mask

     def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
         """Matches each target to the most suitable anchor.
@@ -271,16 +277,15 @@ class BoxMatcher:
         topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)

         # delete one anchor pred assign to mutliple gts
-        unique_indices = self.filter_duplicates(target_matrix)
+        unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)

-        # TODO: do we need grid_mask? Filter the valid groud truth
-        valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()

         align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
         align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
         align_cls = F.one_hot(align_cls, self.class_num)

         # normalize class ditribution
+        iou_mat *= topk_mask
+        target_matrix *= topk_mask
         max_target = target_matrix.amax(dim=-1, keepdim=True)
         max_iou = iou_mat.amax(dim=-1, keepdim=True)
         normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
@@ -295,7 +300,7 @@ class Vec2Box:
         self.device = device

         if hasattr(anchor_cfg, "strides"):
+            logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
             self.strides = anchor_cfg.strides
         else:
             logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
@@ -339,7 +344,7 @@ class Anc2Box:
         self.device = device

         if hasattr(anchor_cfg, "strides"):
+            logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
             self.strides = anchor_cfg.strides
         else:
             logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
@@ -413,7 +418,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
     valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)

     batch_idx, *_ = torch.where(valid_mask)
+    nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
     predicts_nms = []
     for idx in range(cls_dist.size(0)):
         instance_idx = nms_idx[idx == batch_idx[nms_idx]]
@@ -471,3 +476,10 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05)):
         "mAP.5:.95": torch.mean(torch.stack(aps)),
     }
     return mAP
+
+
+def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
+    bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
+    if prediction.size(1) == 6:
+        bbox["scores"] = prediction[:, 5]
+    return bbox
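to_metrics_format is the bridge between the repo's [class, x1, y1, x2, y2, conf] prediction rows and the dict layout torchmetrics expects; 5-column targets simply come back without a scores entry:

    import torch

    from yolo.utils.bounding_box_utils import to_metrics_format

    pred = torch.tensor([[0.0, 0.0, 0.0, 160.0, 120.0, 0.6682]])
    print(to_metrics_format(pred))
    # {'boxes': tensor([[  0.,   0., 160., 120.]]), 'labels': tensor([0], dtype=torch.int32), 'scores': tensor([0.6682])}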
yolo/utils/dataset_utils.py
CHANGED
@@ -5,9 +5,10 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
+import torch

 from yolo.tools.data_conversion import discretize_categories
+from yolo.utils.logger import logger


 def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path]:
@@ -111,3 +112,16 @@ def scale_segmentation(
         seg_array_with_cat.append(scaled_flat_seg_data)

     return seg_array_with_cat
+
+
+def tensorlize(data):
+    img_paths, bboxes = zip(*data)
+    max_box = max(bbox.size(0) for bbox in bboxes)
+    padded_bbox_list = []
+    for bbox in bboxes:
+        padding = torch.full((max_box, 5), -1, dtype=torch.float32)
+        padding[: bbox.size(0)] = bbox
+        padded_bbox_list.append(padding)
+    bboxes = np.stack(padded_bbox_list)
+    img_paths = np.array(img_paths)
+    return img_paths, bboxes
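tensorlize pads every image's label tensor to the batch-wide maximum so the whole dataset can be cached as two flat arrays; the -1 class ids are what get_data later strips out with its valid_mask. A small illustration with made-up data:

    import torch

    from yolo.utils.dataset_utils import tensorlize

    data = [("img0.jpg", torch.zeros(2, 5)), ("img1.jpg", torch.ones(1, 5))]
    img_paths, bboxes = tensorlize(data)
    print(bboxes.shape)  # (2, 2, 5)
    print(bboxes[1, 1])  # [-1. -1. -1. -1. -1.] -- the padding row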
yolo/utils/deploy_utils.py
CHANGED
@@ -1,11 +1,11 @@
 from pathlib import Path
 
 import torch
-from loguru import logger
 from torch import Tensor
 
 from yolo.config.config import Config
 from yolo.model.yolo import create_model
+from yolo.utils.logger import logger
 
 
 class FastModelLoader:
@@ -21,10 +21,10 @@ class FastModelLoader:
 
     def _validate_compiler(self):
         if self.compiler not in ["onnx", "trt", "deploy"]:
-            logger.warning(f"…")
+            logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
             self.compiler = None
         if self.cfg.device == "mps" and self.compiler == "trt":
-            logger.warning("…")
+            logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
             self.compiler = None
 
     def load_model(self, device):
@@ -59,7 +59,7 @@ class FastModelLoader:
         providers = ["CUDAExecutionProvider"]
         try:
             ort_session = InferenceSession(self.model_path, providers=providers)
-            logger.info("…")
+            logger.info(":rocket: Using ONNX as MODEL frameworks!")
         except Exception as e:
             logger.warning(f"π³ Error loading ONNX model: {e}")
             ort_session = self._create_onnx_model(providers)
@@ -79,7 +79,7 @@ class FastModelLoader:
             output_names=["output"],
             dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
         )
-        logger.info(f"…")
+        logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
         return InferenceSession(self.model_path, providers=providers)
 
     def _load_trt_model(self):
@@ -88,7 +88,7 @@ class FastModelLoader:
         try:
             model_trt = TRTModule()
             model_trt.load_state_dict(torch.load(self.model_path))
-            logger.info("…")
+            logger.info(":rocket: Using TensorRT as MODEL frameworks!")
         except FileNotFoundError:
             logger.warning(f"π³ No found model weight at {self.model_path}")
             model_trt = self._create_trt_model()
@@ -102,5 +102,5 @@ class FastModelLoader:
         logger.info(f"♻️ Creating TensorRT model")
         model_trt = torch2trt(model.cuda(), [dummy_input])
         torch.save(model_trt.state_dict(), self.model_path)
-        logger.info(f"…")
+        logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
         return model_trt
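The dynamic_axes mapping in _create_onnx_model is plain torch.onnx.export usage; here is a self-contained sketch with a toy model (everything except the axes dict is illustrative, not this repo's API):

    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    dummy_input = torch.randn(1, 3, 64, 64)

    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        # axis 0 stays symbolic, so the saved session accepts any batch size
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )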
yolo/utils/logger.py
ADDED
@@ -0,0 +1,11 @@
+import logging
+
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
+from rich.console import Console
+from rich.logging import RichHandler
+
+logger = logging.getLogger("yolo")
+logger.setLevel(logging.DEBUG)
+logger.propagate = False
+if rank_zero_only.rank == 0 and not logger.hasHandlers():
+    logger.addHandler(RichHandler(console=Console(), show_level=True, show_path=True, show_time=True, markup=True))
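Because the handler is attached only on global rank 0 and propagate is off, DDP workers stay silent and nothing is printed twice. Any module can then log with rich markup; the messages below are only illustrations:

    from yolo.utils.logger import logger

    logger.info(":rocket: Using ONNX as MODEL frameworks!")  # RichHandler renders the emoji shortcode
    logger.debug("[bold blue]verbose detail[/] shown because the level is DEBUG")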
yolo/utils/logging_utils.py
CHANGED
@@ -11,55 +11,39 @@ Example:
     custom_logger()
 """
 
-import …
-import random
-import sys
+import logging
 from collections import deque
+from logging import FileHandler
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import wandb
-import …
-from loguru import logger
+from lightning import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
+from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
+from rich import get_console, reconfigure
 from rich.console import Console, Group
-from rich.progress import (
-    BarColumn,
-    Progress,
-    SpinnerColumn,
-    TextColumn,
-    TimeRemainingColumn,
-)
+from rich.logging import RichHandler
 from rich.table import Table
+from rich.text import Text
 from torch import Tensor
 from torch.nn import ModuleList
-from torch.optim import Optimizer
-from torchvision.transforms.functional import pil_to_tensor
+from typing_extensions import override
 
 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
-from yolo.tools.drawer import draw_bboxes
+from yolo.utils.logger import logger
 from yolo.utils.solver_utils import make_ap_table
 
 
-def custom_logger(quite: bool = False):
-    logger.remove()
-    if quite:
-        return
-    logger.add(
-        sys.stderr,
-        colorize=True,
-        format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
-    )
-
-
 # TODO: should be moved to correct position
 def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
+    seed_everything(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
@@ -67,189 +51,220 @@ def set_seed(seed):
     torch.backends.cudnn.benchmark = False
 
 
-class …
-    def …
-        …
-        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
-
-        progress_bar = (
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            BarColumn(bar_width=None),
-            TextColumn("{task.completed:.0f}/{task.total:.0f}"),
-            TimeRemainingColumn(),
-        )
-        self.ap_table = Table()
-        # TODO: load maxlen by config files
-        self.ap_past_list = deque(maxlen=5)
-        self.last_result = 0
-        super().__init__(*args, *progress_bar, **kwargs)
-
-        self.use_wandb = cfg.use_wandb
-        if self.use_wandb and self.local_rank == 0:
-            wandb.errors.term._log = custom_wandb_log
-            self.wandb = wandb.init(
-                project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
-            )
-
-        self.use_tensorboard = cfg.use_tensorboard
-        if self.use_tensorboard and self.local_rank == 0:
-            from torch.utils.tensorboard import SummaryWriter
-
-            …
-
-    def …
-        …
-
-    def start_train(self, num_epochs: int):
-        self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
-        self.update(self.task_epoch, advance=-0.5)
-
-    @rank_check
-    def start_one_epoch(
-        self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
-    ):
-        self.num_batches = num_batches
-        self.task = task
-        if hasattr(self, "task_epoch"):
-            self.update(self.task_epoch, description=f"[cyan] Preparing Data")
-
-        if optimizer is not None:
-            lr_values = [params["lr"] for params in optimizer.param_groups]
-            lr_names = ["Learning Rate/bias", "Learning Rate/norm", "Learning Rate/conv"]
-            if self.use_wandb:
-                for lr_name, lr_value in zip(lr_names, lr_values):
-                    self.wandb.log({lr_name: lr_value}, step=epoch_idx)
-
-            if self.use_tensorboard:
-                for lr_name, lr_value in zip(lr_names, lr_values):
-                    self.tb_writer.add_scalar(lr_name, lr_value, global_step=epoch_idx)
-
-        self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
-
-    @rank_check
-    def one_batch(self, batch_info: Dict[str, Tensor] = None):
-        epoch_descript = "[cyan]" + self.task + "[white] |"
-        batch_descript = "|"
-        if self.task == "Train":
-            self.update(self.task_epoch, advance=1 / self.num_batches)
-        for info_name, info_val in batch_info.items():
-            epoch_descript += f"{info_name: ^9}|"
-            batch_descript += f" {info_val:2.2f} |"
-        self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
-        if hasattr(self, "task_epoch"):
-            self.update(self.task_epoch, description=epoch_descript)
-
-    @rank_check
-    def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
-        if self.task == "Train":
-            prefix = "Loss"
-        elif self.task == "Validate":
-            prefix = "Metrics"
-        batch_info = {f"{prefix}/{key}": value for key, value in batch_info.items()}
-        if self.use_wandb:
-            self.wandb.log(batch_info, step=epoch_idx)
-        if self.use_tensorboard:
-            for key, value in batch_info.items():
-                self.tb_writer.add_scalar(key, value, epoch_idx)
-
-        self.remove_task(self.batch_task)
-
-    @rank_check
-    def visualize_image(
-        self,
-        images: Optional[Tensor] = None,
-        ground_truth: Optional[Tensor] = None,
-        prediction: Optional[Union[List[Tensor], Tensor]] = None,
-        epoch_idx: int = 0,
-    ) -> None:
-        """
-        Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
-
-        Args:
-            images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
-            ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
-            prediction (prediction: Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6) or (N, 6). Defaults to None.
-            epoch_idx (int): Current epoch index. Defaults to 0.
-        """
-        if images is not None:
-            images = images[0] if images.ndim == 4 else images
-            if self.use_wandb:
-                wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Input Image", images, 1)
-
-        if ground_truth is not None:
-            gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
-            if self.use_wandb:
-                wandb.log(
-                    {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
-                    step=epoch_idx,
-                )
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
-
-        if prediction is not None:
-            pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
-            if self.use_wandb:
-                wandb.log(
-                    {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
-                    step=epoch_idx,
-                )
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
-
-    @rank_check
-    def start_pycocotools(self):
-        self.batch_task = self.add_task("[green]Run pycocotools", total=1)
-
-    @rank_check
-    def finish_pycocotools(self, result, epoch_idx=-1):
-        ap_table, ap_main = make_ap_table(result * 100, self.ap_past_list, self.last_result, epoch_idx)
-        self.last_result = np.maximum(result, self.last_result)
-        self.ap_past_list.append((epoch_idx, ap_main))
-        self.ap_table = ap_table
-
-        if self.use_wandb:
-            self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
-        if self.use_tensorboard:
-            # TODO: waiting torch bugs fix, https://github.com/pytorch/pytorch/issues/32651
-            self.tb_writer.add_scalar("PyCOCO/AP @ .5:.95", ap_main[2], epoch_idx)
-            self.tb_writer.add_scalar("PyCOCO/AP @ .5", ap_main[5], epoch_idx)
-
-        self.update(self.batch_task, advance=1)
-        self.refresh()
-        self.remove_task(self.batch_task)
-
-    def finish_train(self):
-        self.remove_task(self.task_epoch)
-        self.stop()
-        if self.use_wandb:
-            self.wandb.finish()
-        if self.use_tensorboard:
-            self.tb_writer.close()
-
-
-def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
-    if silent:
-        return
-    for line in string.split("\n"):
-        logger.opt(raw=not newline, colors=True).info("🌐 " + line)
+class YOLOCustomProgress(CustomProgress):
+    def get_renderable(self):
+        renderable = Group(*self.get_renderables())
+        if hasattr(self, "table"):
+            renderable = Group(*self.get_renderables(), self.table)
+        return renderable
+
+
+class YOLORichProgressBar(RichProgressBar):
+    @override
+    @rank_zero_only
+    def _init_progress(self, trainer: "Trainer") -> None:
+        if self.is_enabled and (self.progress is None or self._progress_stopped):
+            self._reset_progress_bar_ids()
+            reconfigure(**self._console_kwargs)
+            self._console = Console()
+            self._console.clear_live()
+            self.progress = YOLOCustomProgress(
+                *self.configure_columns(trainer),
+                auto_refresh=False,
+                disable=self.is_disabled,
+                console=self._console,
+            )
+            self.progress.start()
+
+            self._progress_stopped = False
+
+            self.max_result = 0
+            self.past_results = deque(maxlen=5)
+            self.progress.table = Table()
+
+    @override
+    def _get_train_description(self, current_epoch: int) -> str:
+        return Text("[cyan]Train [white]|")
+
+    @override
+    @rank_zero_only
+    def on_train_start(self, trainer, pl_module):
+        self._init_progress(trainer)
+        num_epochs = trainer.max_epochs - 1
+        self.task_epoch = self._add_task(
+            total_batches=num_epochs,
+            description=f"[cyan]Start Training {num_epochs} epochs",
+        )
+        self.max_result = 0
+        self.past_results.clear()
+        self.progress.update(self.task_epoch, advance=-0.5)
+
+    @override
+    @rank_zero_only
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
+        self._update(self.train_progress_bar_id, batch_idx + 1)
+        self._update_metrics(trainer, pl_module)
+        epoch_descript = "[cyan]Train [white]|"
+        batch_descript = "[green]Train [white]|"
+        metrics = self.get_metrics(trainer, pl_module)
+        metrics.pop("v_num")
+        for metrics_name, metrics_val in metrics.items():
+            if "Loss_step" in metrics_name:
+                epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
+                batch_descript += f" {metrics_val:2.2f} |"
+
+        self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
+        self.progress.update(self.train_progress_bar_id, description=batch_descript)
+        self.refresh()
+
+    @override
+    @rank_zero_only
+    def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        self._update_metrics(trainer, pl_module)
+        self.progress.remove_task(self.train_progress_bar_id)
+        self.train_progress_bar_id = None
+
+    @override
+    @rank_zero_only
+    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        if trainer.state.fn == "fit":
+            self._update_metrics(trainer, pl_module)
+        self.reset_dataloader_idx_tracker()
+        all_metrics = self.get_metrics(trainer, pl_module)
+
+        ap_ar_list = [
+            key
+            for key in all_metrics.keys()
+            if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
+        ]
+        score = np.array([all_metrics[key] for key in ap_ar_list]) * 100
+
+        self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
+        self.max_result = np.maximum(score, self.max_result)
+        self.past_results.append((trainer.current_epoch, ap_main))
+
+    @override
+    def refresh(self) -> None:
+        if self.progress:
+            self.progress.refresh()
+
+    @property
+    def validation_description(self) -> str:
+        return "[green]Validation"
+
+
+class YOLORichModelSummary(RichModelSummary):
+    @staticmethod
+    @override
+    def summarize(
+        summary_data: List[Tuple[str, List[str]]],
+        total_parameters: int,
+        trainable_parameters: int,
+        model_size: float,
+        total_training_modes: Dict[str, int],
+        **summarize_kwargs: Any,
+    ) -> None:
+        from lightning.pytorch.utilities.model_summary import get_human_readable_count
+
+        console = get_console()
+
+        header_style: str = summarize_kwargs.get("header_style", "bold magenta")
+        table = Table(header_style=header_style)
+        table.add_column(" ", style="dim")
+        table.add_column("Name", justify="left", no_wrap=True)
+        table.add_column("Type")
+        table.add_column("Params", justify="right")
+        table.add_column("Mode")
+
+        column_names = list(zip(*summary_data))[0]
+
+        for column_name in ["In sizes", "Out sizes"]:
+            if column_name in column_names:
+                table.add_column(column_name, justify="right", style="white")
+
+        rows = list(zip(*(arr[1] for arr in summary_data)))
+        for row in rows:
+            table.add_row(*row)
+
+        console.print(table)
+
+        parameters = []
+        for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
+            parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
+
+        grid = Table(header_style=header_style)
+        table.add_column(" ", style="dim")
+        grid.add_column("[bold]Attributes[/]")
+        grid.add_column("Value")
+
+        grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
+        grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
+        grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
+        grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
+        grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
+        grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")
+
+        console.print(grid)
+
+
+class ImageLogger(Callback):
+    def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
+        if batch_idx != 0:
+            return
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        gt_boxes = targets[0] if targets.ndim == 3 else targets
+        pred_boxes = outputs[0] if isinstance(outputs, list) else outputs
+        images = [images[0]]
+        step = trainer.current_epoch
+        for logger in trainer.loggers:
+            if isinstance(logger, WandbLogger):
+                logger.log_image("Input Image", images, step=step)
+                logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)])
+                logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)])
+
+
+def setup_logger(logger_name):
+    class EmojiFormatter(logging.Formatter):
+        def format(self, record, emoji=":high_voltage:"):
+            return f"{emoji} {super().format(record)}"
+
+    rich_handler = RichHandler(markup=True)
+    rich_handler.setFormatter(EmojiFormatter("%(message)s"))
+    rich_logger = logging.getLogger(logger_name)
+    if rich_logger:
+        rich_logger.handlers.clear()
+        rich_logger.addHandler(rich_handler)
+
+
+def setup(cfg: Config):
+    # seed_everything(cfg.lucky_number)
+    if hasattr(cfg, "quite"):
+        logger.removeHandler("YOLO_logger")
+        return
+
+    setup_logger("lightning.fabric")
+    setup_logger("lightning.pytorch")
+
+    def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
+        if silent:
+            return
+        for line in string.split("\n"):
+            logger.info(Text.from_ansi(":globe_with_meridians: " + line))
+
+    wandb.errors.term._log = custom_wandb_log
+
+    save_path = validate_log_directory(cfg, cfg.name)
+
+    progress, loggers = [], []
+    progress.append(YOLORichProgressBar())
+    progress.append(YOLORichModelSummary())
+    progress.append(ImageLogger())
+    if cfg.use_tensorboard:
+        loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path))
+    if cfg.use_wandb:
+        loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))
+
+    return progress, loggers
 
 
 def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
@@ -279,6 +294,7 @@ def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
     console.print(table)
 
 
+@rank_zero_only
 def validate_log_directory(cfg: Config, exp_name: str) -> Path:
     base_path = Path(cfg.out_path, cfg.task.task)
     save_path = base_path / exp_name
@@ -296,8 +312,8 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
     )
 
     save_path.mkdir(parents=True, exist_ok=True)
-    logger.…
-    logger.…
+    logger.info(f"π Created log folder: [blue b u]{save_path}[/]")
+    logger.addHandler(FileHandler(save_path / "output.log"))
     return save_path
@@ -332,4 +348,4 @@ def log_bbox(
         bbox_entry["scores"] = {"confidence": conf[0]}
     bbox_list.append(bbox_entry)
 
-    return bbox_list
+    return {"predictions": {"box_data": bbox_list}}
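The setup helper above returns callback and logger lists; a sketch of how they would plug into a stock Lightning Trainer (the max_epochs value is arbitrary, chosen only because on_train_start treats trainer.max_epochs - 1 as the displayed epoch count, and cfg stands for the project's Config object):

    from lightning import Trainer

    progress, loggers = setup(cfg)  # cfg: the project's omegaconf-backed Config
    trainer = Trainer(
        max_epochs=11,        # the bar reads "Start Training 10 epochs"
        callbacks=progress,   # YOLORichProgressBar, YOLORichModelSummary, ImageLogger
        logger=loggers,       # TensorBoardLogger / WandbLogger when enabled
    )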
yolo/utils/model_utils.py
CHANGED
@@ -4,7 +4,6 @@ from typing import List, Optional, Type, Union
 
 import torch
 import torch.distributed as dist
-from loguru import logger
 from omegaconf import ListConfig
 from torch import Tensor
 from torch.optim import Optimizer
@@ -13,6 +12,7 @@ from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
 from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
 from yolo.model.yolo import YOLO
 from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
+from yolo.utils.logger import logger
 
 
 class ExponentialMovingAverage:
@@ -52,9 +52,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
     conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
 
     model_parameters = [
-        {"params": bias_params, "weight_decay": 0},
-        {"params": conv_params},
-        {"params": norm_params, "weight_decay": 0},
+        {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
+        {"params": conv_params, "momentum": 0.8},
+        {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
     ]
 
     def next_epoch(self, batch_num):
@@ -65,12 +65,16 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
 
     def next_batch(self):
         self.batch_idx += 1
+        lr_dict = dict()
         for lr_idx, param_group in enumerate(self.param_groups):
             min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
             param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
+            lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
+        return lr_dict
 
     optimizer_class.next_batch = next_batch
     optimizer_class.next_epoch = next_epoch
+
     optimizer = optimizer_class(model_parameters, **optim_cfg.args)
     optimizer.max_lr = [0.1, 0, 0]
     return optimizer
@@ -168,6 +172,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
     batch_json = []
     for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
         scale, shift = box_reverse.split([1, 4])
+        bboxes = bboxes.clone()
         bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
         bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
        for cls, *pos, conf in bboxes:
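The new bboxes.clone() keeps predicts_to_json from rescaling the caller's tensors in place, since the same predictions are also consumed elsewhere after this call. A tiny demonstration of the aliasing the copy prevents:

    import torch

    preds = torch.tensor([[0.0, 10.0, 10.0, 50.0, 50.0, 0.9]])
    alias = preds                      # same storage, no copy made
    alias[:, 1:5] = alias[:, 1:5] / 2  # in-place edit ...
    print(preds[0, 1])                 # tensor(5.) ... leaks into the "original"

    safe = preds.clone()               # private storage, as the fix does
    safe[:, 1:5] = safe[:, 1:5] * 0
    print(preds[0, 1])                 # tensor(5.), untouched this time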
yolo/utils/solver_utils.py
CHANGED
@@ -1,5 +1,6 @@
 import contextlib
 import io
+from typing import Dict
 
 import numpy as np
 from pycocotools.coco import COCO
@@ -17,7 +18,7 @@ def calculate_ap(coco_gt: COCO, pd_path):
     return coco_eval.stats
 
 
-def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
+def make_ap_table(score: Dict[str, float], past_result=[], max_result=None, epoch=-1):
     ap_table = Table()
     ap_table.add_column("Epoch", justify="center", style="white", width=5)
     ap_table.add_column("Avg. Precision", justify="left", style="cyan")
@@ -30,7 +31,7 @@ def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
     if past_result:
         ap_table.add_row()
 
-    color = np.where(last_score <= score, "[green]", "[red]")
+    color = np.where(max_result <= score, "[green]", "[red]")
 
     this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
     metrics = [
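The elementwise np.where assigns a rich color tag per metric: entries that match or beat the running best render green, regressions render red. A quick check of that behavior:

    import numpy as np

    max_result = np.array([52.1, 69.3])  # best scores so far
    score = np.array([53.0, 68.9])       # this epoch
    print(np.where(max_result <= score, "[green]", "[red]"))
    # ['[green]' '[red]'] -> AP .5:.95 improved, AP .5 regressed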