henry000 commited on
Commit
17274a5
Β·
2 Parent(s): b1f0abc cdd9a11

πŸ”€ [Merge] branch 'main' into MODEL

Browse files
.github/workflows/deploy.yaml CHANGED
@@ -12,7 +12,7 @@ jobs:
12
 
13
  strategy:
14
  matrix:
15
- operating-system: [ubuntu-latest, macos-latest]
16
  python-version: [3.8, '3.10']
17
  fail-fast: false
18
 
@@ -53,17 +53,17 @@ jobs:
53
 
54
  - name: Run Validation
55
  run: |
56
- python yolo/lazy.py task=validation dataset=mock
57
- python yolo/lazy.py task=validation dataset=mock model=v9-s
58
- python yolo/lazy.py task=validation dataset=mock name=AnyNameYouWant
59
 
60
  - name: Run Inference
61
  run: |
62
- python yolo/lazy.py task=inference
63
- python yolo/lazy.py task=inference model=v7
64
- python yolo/lazy.py task=inference +quite=True
65
- python yolo/lazy.py task=inference name=AnyNameYouWant
66
- python yolo/lazy.py task=inference image_size=\[480,640]
67
- python yolo/lazy.py task=inference task.nms.min_confidence=0.1
68
- python yolo/lazy.py task=inference task.fast_inference=deploy
69
- python yolo/lazy.py task=inference task.data.source=tests/data/images/val
 
12
 
13
  strategy:
14
  matrix:
15
+ operating-system: [ubuntu-latest, windows-latest]
16
  python-version: [3.8, '3.10']
17
  fail-fast: false
18
 
 
53
 
54
  - name: Run Validation
55
  run: |
56
+ python yolo/lazy.py task=validation use_wandb=False dataset=mock
57
+ python yolo/lazy.py task=validation use_wandb=False dataset=mock model=v9-s
58
+ python yolo/lazy.py task=validation use_wandb=False dataset=mock name=AnyNameYouWant
59
 
60
  - name: Run Inference
61
  run: |
62
+ python yolo/lazy.py task=inference use_wandb=False
63
+ python yolo/lazy.py task=inference use_wandb=False model=v7
64
+ python yolo/lazy.py task=inference use_wandb=False +quite=True
65
+ python yolo/lazy.py task=inference use_wandb=False name=AnyNameYouWant
66
+ python yolo/lazy.py task=inference use_wandb=False image_size=\[480,640]
67
+ python yolo/lazy.py task=inference use_wandb=False task.nms.min_confidence=0.1
68
+ python yolo/lazy.py task=inference use_wandb=False task.fast_inference=deploy
69
+ python yolo/lazy.py task=inference use_wandb=False task.data.source=tests/data/images/val
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  einops
2
  graphviz
3
  hydra-core
 
4
  loguru
5
  numpy
6
  opencv-python
 
1
  einops
2
  graphviz
3
  hydra-core
4
+ lightning
5
  loguru
6
  numpy
7
  opencv-python
tests/conftest.py CHANGED
@@ -4,15 +4,16 @@ from pathlib import Path
4
  import pytest
5
  import torch
6
  from hydra import compose, initialize
 
7
 
8
  project_root = Path(__file__).resolve().parent.parent
9
  sys.path.append(str(project_root))
10
 
11
  from yolo import Anc2Box, Config, Vec2Box, create_converter, create_model
12
  from yolo.model.yolo import YOLO
13
- from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
14
  from yolo.tools.dataset_preparation import prepare_dataset
15
- from yolo.utils.logging_utils import ProgressLogger, set_seed
16
 
17
 
18
  def pytest_configure(config):
@@ -52,18 +53,6 @@ def device():
52
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
 
54
 
55
- @pytest.fixture(scope="session")
56
- def train_progress_logger(train_cfg: Config):
57
- progress_logger = ProgressLogger(train_cfg, exp_name=train_cfg.name)
58
- return progress_logger
59
-
60
-
61
- @pytest.fixture(scope="session")
62
- def validation_progress_logger(validation_cfg: Config):
63
- progress_logger = ProgressLogger(validation_cfg, exp_name=validation_cfg.name)
64
- return progress_logger
65
-
66
-
67
  @pytest.fixture(scope="session")
68
  def model(train_cfg: Config, device) -> YOLO:
69
  model = create_model(train_cfg.model)
@@ -76,6 +65,24 @@ def model_v7(inference_v7_cfg: Config, device) -> YOLO:
76
  return model.to(device)
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  @pytest.fixture(scope="session")
80
  def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
81
  vec2box = create_converter(train_cfg.model.name, model, train_cfg.model.anchor, train_cfg.image_size, device)
@@ -93,13 +100,13 @@ def anc2box(inference_v7_cfg: Config, model: YOLO, device) -> Anc2Box:
93
  @pytest.fixture(scope="session")
94
  def train_dataloader(train_cfg: Config):
95
  prepare_dataset(train_cfg.dataset, task="train")
96
- return YoloDataLoader(train_cfg.task.data, train_cfg.dataset, train_cfg.task.task)
97
 
98
 
99
  @pytest.fixture(scope="session")
100
  def validation_dataloader(validation_cfg: Config):
101
  prepare_dataset(validation_cfg.dataset, task="val")
102
- return YoloDataLoader(validation_cfg.task.data, validation_cfg.dataset, validation_cfg.task.task)
103
 
104
 
105
  @pytest.fixture(scope="session")
 
4
  import pytest
5
  import torch
6
  from hydra import compose, initialize
7
+ from lightning import Trainer
8
 
9
  project_root = Path(__file__).resolve().parent.parent
10
  sys.path.append(str(project_root))
11
 
12
  from yolo import Anc2Box, Config, Vec2Box, create_converter, create_model
13
  from yolo.model.yolo import YOLO
14
+ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
15
  from yolo.tools.dataset_preparation import prepare_dataset
16
+ from yolo.utils.logging_utils import set_seed, setup
17
 
18
 
19
  def pytest_configure(config):
 
53
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @pytest.fixture(scope="session")
57
  def model(train_cfg: Config, device) -> YOLO:
58
  model = create_model(train_cfg.model)
 
65
  return model.to(device)
66
 
67
 
68
+ @pytest.fixture(scope="session")
69
+ def solver(train_cfg: Config) -> Trainer:
70
+ train_cfg.use_wandb = False
71
+ callbacks, loggers, save_path = setup(train_cfg)
72
+ trainer = Trainer(
73
+ accelerator="auto",
74
+ max_epochs=getattr(train_cfg.task, "epoch", None),
75
+ precision="16-mixed",
76
+ callbacks=callbacks,
77
+ logger=loggers,
78
+ log_every_n_steps=1,
79
+ gradient_clip_val=10,
80
+ deterministic=True,
81
+ default_root_dir=save_path,
82
+ )
83
+ return trainer
84
+
85
+
86
  @pytest.fixture(scope="session")
87
  def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
88
  vec2box = create_converter(train_cfg.model.name, model, train_cfg.model.anchor, train_cfg.image_size, device)
 
100
  @pytest.fixture(scope="session")
101
  def train_dataloader(train_cfg: Config):
102
  prepare_dataset(train_cfg.dataset, task="train")
103
+ return create_dataloader(train_cfg.task.data, train_cfg.dataset, train_cfg.task.task)
104
 
105
 
106
  @pytest.fixture(scope="session")
107
  def validation_dataloader(validation_cfg: Config):
108
  prepare_dataset(validation_cfg.dataset, task="val")
109
+ return create_dataloader(validation_cfg.task.data, validation_cfg.dataset, validation_cfg.task.task)
110
 
111
 
112
  @pytest.fixture(scope="session")
tests/test_tools/test_data_loader.py CHANGED
@@ -1,11 +1,13 @@
1
  import sys
2
  from pathlib import Path
3
 
 
 
4
  project_root = Path(__file__).resolve().parent.parent.parent
5
  sys.path.append(str(project_root))
6
 
7
  from yolo.config.config import Config
8
- from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader, create_dataloader
9
 
10
 
11
  def test_create_dataloader_cache(train_cfg: Config):
@@ -25,7 +27,7 @@ def test_create_dataloader_cache(train_cfg: Config):
25
  assert m_image_paths == l_image_paths
26
 
27
 
28
- def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
29
  """Test that the training data loader produces correctly shaped data and metadata."""
30
  batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
31
  assert batch_size == 2
@@ -38,7 +40,7 @@ def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
38
  assert list(image_paths) == list(expected_paths)
39
 
40
 
41
- def test_validation_data_loader_correctness(validation_dataloader: YoloDataLoader):
42
  batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
43
  assert batch_size == 4
44
  assert images.shape == (4, 3, 640, 640)
 
1
  import sys
2
  from pathlib import Path
3
 
4
+ from torch.utils.data import DataLoader
5
+
6
  project_root = Path(__file__).resolve().parent.parent.parent
7
  sys.path.append(str(project_root))
8
 
9
  from yolo.config.config import Config
10
+ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
11
 
12
 
13
  def test_create_dataloader_cache(train_cfg: Config):
 
27
  assert m_image_paths == l_image_paths
28
 
29
 
30
+ def test_training_data_loader_correctness(train_dataloader: DataLoader):
31
  """Test that the training data loader produces correctly shaped data and metadata."""
32
  batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
33
  assert batch_size == 2
 
40
  assert list(image_paths) == list(expected_paths)
41
 
42
 
43
+ def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
44
  batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
45
  assert batch_size == 4
46
  assert images.shape == (4, 3, 640, 640)
tests/test_tools/test_loss_functions.py CHANGED
@@ -1,4 +1,5 @@
1
  import sys
 
2
  from pathlib import Path
3
 
4
  import pytest
@@ -51,6 +52,6 @@ def test_yolo_loss(loss_function, data):
51
  predicts, targets = data
52
  loss, loss_dict = loss_function(predicts, predicts, targets)
53
  assert torch.isnan(loss)
54
- assert torch.isnan(loss_dict["BoxLoss"])
55
- assert torch.isnan(loss_dict["DFLoss"])
56
- assert torch.isinf(loss_dict["BCELoss"])
 
1
  import sys
2
+ from math import isinf, isnan
3
  from pathlib import Path
4
 
5
  import pytest
 
52
  predicts, targets = data
53
  loss, loss_dict = loss_function(predicts, predicts, targets)
54
  assert torch.isnan(loss)
55
+ assert isnan(loss_dict["Loss/BoxLoss"])
56
+ assert isnan(loss_dict["Loss/DFLLoss"])
57
+ assert isinf(loss_dict["Loss/BCELoss"])
tests/test_tools/test_solver.py CHANGED
@@ -1,79 +1,81 @@
1
  import sys
 
2
  from pathlib import Path
3
 
4
  import pytest
5
- from torch import allclose, tensor
 
6
 
7
  project_root = Path(__file__).resolve().parent.parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
  from yolo.model.yolo import YOLO
12
- from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
13
- from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
14
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
15
 
16
 
17
  @pytest.fixture
18
- def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
19
- validator = ModelValidator(
20
- validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
21
- )
22
  return validator
23
 
24
 
25
- def test_model_validator_initialization(model_validator: ModelValidator):
26
  assert isinstance(model_validator.model, YOLO)
27
- assert hasattr(model_validator, "solve")
28
 
29
 
30
- def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
31
- mAPs = model_validator.solve(validation_dataloader)
32
- except_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
33
- assert allclose(mAPs["mAP.5"], except_mAPs["mAP.5"], rtol=0.1)
34
- print(mAPs)
35
- assert allclose(mAPs["mAP.5:.95"], except_mAPs["mAP.5:.95"], rtol=0.1)
 
36
 
37
 
38
  @pytest.fixture
39
- def model_tester(inference_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
40
- tester = ModelTester(inference_cfg, model, vec2box, validation_progress_logger, device)
41
  return tester
42
 
43
 
44
  @pytest.fixture
45
- def modelv7_tester(inference_v7_cfg: Config, model_v7: YOLO, anc2box: Anc2Box, validation_progress_logger, device):
46
- tester = ModelTester(inference_v7_cfg, model_v7, anc2box, validation_progress_logger, device)
47
  return tester
48
 
49
 
50
- def test_model_tester_initialization(model_tester: ModelTester):
51
  assert isinstance(model_tester.model, YOLO)
52
- assert hasattr(model_tester, "solve")
53
 
54
 
55
- def test_model_tester_solve_single_image(model_tester: ModelTester, file_stream_data_loader: StreamDataLoader):
56
- model_tester.solve(file_stream_data_loader)
 
 
57
 
58
 
59
- def test_modelv7_tester_solve_single_image(modelv7_tester: ModelTester, file_stream_data_loader_v7: StreamDataLoader):
60
- modelv7_tester.solve(file_stream_data_loader_v7)
 
 
61
 
62
 
63
  @pytest.fixture
64
- def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
65
  train_cfg.task.epoch = 2
66
- trainer = ModelTrainer(train_cfg, model, vec2box, train_progress_logger, device, use_ddp=False)
67
  return trainer
68
 
69
 
70
- def test_model_trainer_initialization(model_trainer: ModelTrainer):
71
-
72
  assert isinstance(model_trainer.model, YOLO)
73
- assert hasattr(model_trainer, "solve")
74
- assert model_trainer.optimizer is not None
75
- assert model_trainer.scheduler is not None
76
- assert model_trainer.loss_fn is not None
77
 
78
 
79
  # def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
 
1
  import sys
2
+ from math import isclose
3
  from pathlib import Path
4
 
5
  import pytest
6
+ from lightning.pytorch import Trainer
7
+ from torch.utils.data import DataLoader
8
 
9
  project_root = Path(__file__).resolve().parent.parent.parent
10
  sys.path.append(str(project_root))
11
 
12
  from yolo.config.config import Config
13
  from yolo.model.yolo import YOLO
14
+ from yolo.tools.data_loader import StreamDataLoader
15
+ from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
16
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
17
 
18
 
19
  @pytest.fixture
20
+ def model_validator(validation_cfg: Config):
21
+ validator = ValidateModel(validation_cfg)
 
 
22
  return validator
23
 
24
 
25
+ def test_model_validator_initialization(solver: Trainer, model_validator: ValidateModel):
26
  assert isinstance(model_validator.model, YOLO)
27
+ assert hasattr(solver, "validate")
28
 
29
 
30
+ def test_model_validator_solve_mock_dataset(
31
+ solver: Trainer, model_validator: ValidateModel, validation_dataloader: DataLoader
32
+ ):
33
+ mAPs = solver.validate(model_validator, dataloaders=validation_dataloader)[0]
34
+ except_mAPs = {"map_50": 0.7379, "map": 0.5617}
35
+ assert isclose(mAPs["map_50"], except_mAPs["map_50"], abs_tol=0.1)
36
+ assert isclose(mAPs["map"], except_mAPs["map"], abs_tol=0.1)
37
 
38
 
39
  @pytest.fixture
40
+ def model_tester(inference_cfg: Config):
41
+ tester = InferenceModel(inference_cfg)
42
  return tester
43
 
44
 
45
  @pytest.fixture
46
+ def modelv7_tester(inference_v7_cfg: Config):
47
+ tester = InferenceModel(inference_v7_cfg)
48
  return tester
49
 
50
 
51
+ def test_model_tester_initialization(solver: Trainer, model_tester: InferenceModel):
52
  assert isinstance(model_tester.model, YOLO)
53
+ assert hasattr(solver, "predict")
54
 
55
 
56
+ def test_model_tester_solve_single_image(
57
+ solver: Trainer, model_tester: InferenceModel, file_stream_data_loader: StreamDataLoader
58
+ ):
59
+ solver.predict(model_tester, file_stream_data_loader)
60
 
61
 
62
+ def test_modelv7_tester_solve_single_image(
63
+ solver: Trainer, modelv7_tester: InferenceModel, file_stream_data_loader_v7: StreamDataLoader
64
+ ):
65
+ solver.predict(modelv7_tester, file_stream_data_loader_v7)
66
 
67
 
68
  @pytest.fixture
69
+ def model_trainer(train_cfg: Config):
70
  train_cfg.task.epoch = 2
71
+ trainer = TrainModel(train_cfg)
72
  return trainer
73
 
74
 
75
+ def test_model_trainer_initialization(solver: Trainer, model_trainer: TrainModel):
 
76
  assert isinstance(model_trainer.model, YOLO)
77
+ assert hasattr(solver, "fit")
78
+ assert solver.optimizers is not None
 
 
79
 
80
 
81
  # def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
tests/test_utils/test_bounding_box_utils.py CHANGED
@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):
146
 
147
 
148
  def test_bbox_nms():
149
- cls_dist = tensor(
150
- [[[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [[0.4, 0.4, 0.2], [0.5, 0.4, 0.1]]] # Example class distribution
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
- bbox = tensor(
153
- [[[50, 50, 100, 100], [60, 60, 110, 110]], [[40, 40, 90, 90], [70, 70, 120, 120]]], # Example bounding boxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  dtype=float32,
155
  )
 
156
  nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
157
 
158
- expected_output = [
159
- tensor(
 
 
 
 
 
 
 
 
160
  [
161
- [1.0000, 50.0000, 50.0000, 100.0000, 100.0000, 0.6682],
162
- [0.0000, 60.0000, 60.0000, 110.0000, 110.0000, 0.6457],
163
- ]
164
- )
165
- ]
 
 
 
 
166
 
167
  output = bbox_nms(cls_dist, bbox, nms_cfg)
168
 
@@ -175,9 +216,8 @@ def test_calculate_map():
175
  ground_truths = tensor([[0, 50, 50, 150, 150], [0, 30, 30, 100, 100]]) # [class, x1, y1, x2, y2]
176
 
177
  mAP = calculate_map(predictions, ground_truths)
 
 
178
 
179
- expected_ap50 = tensor(0.5)
180
- expected_ap50_95 = tensor(0.2)
181
-
182
- assert isclose(mAP["mAP.5"], expected_ap50, atol=1e-5), f"AP50 mismatch"
183
- assert isclose(mAP["mAP.5:.95"], expected_ap50_95, atol=1e-5), f"Mean AP mismatch"
 
146
 
147
 
148
  def test_bbox_nms():
149
+ cls_dist = torch.tensor(
150
+ [
151
+ [
152
+ [0.7, 0.1, 0.2], # High confidence, class 0
153
+ [0.3, 0.6, 0.1], # High confidence, class 1
154
+ [-3.0, -2.0, -1.0], # low confidence, class 2
155
+ [0.6, 0.2, 0.2], # Medium confidence, class 0
156
+ ],
157
+ [
158
+ [0.55, 0.25, 0.2], # Medium confidence, class 0
159
+ [-4.0, -0.5, -2.0], # low confidence, class 1
160
+ [0.15, 0.2, 0.65], # Medium confidence, class 2
161
+ [0.8, 0.1, 0.1], # High confidence, class 0
162
+ ],
163
+ ],
164
+ dtype=float32,
165
  )
166
+
167
+ bbox = torch.tensor(
168
+ [
169
+ [
170
+ [0, 0, 160, 120], # Overlaps with box 4
171
+ [160, 120, 320, 240],
172
+ [0, 120, 160, 240],
173
+ [16, 12, 176, 132],
174
+ ],
175
+ [
176
+ [0, 0, 160, 120], # Overlaps with box 4
177
+ [160, 120, 320, 240],
178
+ [0, 120, 160, 240],
179
+ [16, 12, 176, 132],
180
+ ],
181
+ ],
182
  dtype=float32,
183
  )
184
+
185
  nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
186
 
187
+ # Batch 1:
188
+ # - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
189
+ # - box 2 is kept with class 1
190
+ # - box 3 is rejected by the confidence filter
191
+ # Batch 2:
192
+ # - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
193
+ # - box 2 is rejected by the confidence filter
194
+ # - box 3 is kept with class 2
195
+ expected_output = torch.tensor(
196
+ [
197
  [
198
+ [0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
199
+ [1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
200
+ ],
201
+ [
202
+ [0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
203
+ [2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
204
+ ],
205
+ ]
206
+ )
207
 
208
  output = bbox_nms(cls_dist, bbox, nms_cfg)
209
 
 
216
  ground_truths = tensor([[0, 50, 50, 150, 150], [0, 30, 30, 100, 100]]) # [class, x1, y1, x2, y2]
217
 
218
  mAP = calculate_map(predictions, ground_truths)
219
+ expected_ap50 = tensor(0.5050)
220
+ expected_ap50_95 = tensor(0.2020)
221
 
222
+ assert isclose(mAP["map_50"], expected_ap50, atol=1e-4), f"AP50 mismatch"
223
+ assert isclose(mAP["map"], expected_ap50_95, atol=1e-4), f"Mean AP mismatch"
 
 
 
yolo/__init__.py CHANGED
@@ -2,18 +2,22 @@ from yolo.config.config import Config, NMSConfig
2
  from yolo.model.yolo import create_model
3
  from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
  from yolo.tools.drawer import draw_bboxes
5
- from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
7
  from yolo.utils.deploy_utils import FastModelLoader
8
- from yolo.utils.logging_utils import ProgressLogger, custom_logger
9
- from yolo.utils.model_utils import PostProccess
 
 
 
 
10
 
11
  all = [
12
  "create_model",
13
  "Config",
14
- "ProgressLogger",
15
  "NMSConfig",
16
- "custom_logger",
17
  "validate_log_directory",
18
  "draw_bboxes",
19
  "Vec2Box",
@@ -21,10 +25,9 @@ all = [
21
  "bbox_nms",
22
  "create_converter",
23
  "AugmentationComposer",
 
24
  "create_dataloader",
25
  "FastModelLoader",
26
- "ModelTester",
27
- "ModelTrainer",
28
- "ModelValidator",
29
- "PostProccess",
30
  ]
 
2
  from yolo.model.yolo import create_model
3
  from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
  from yolo.tools.drawer import draw_bboxes
5
+ from yolo.tools.solver import TrainModel
6
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
7
  from yolo.utils.deploy_utils import FastModelLoader
8
+ from yolo.utils.logging_utils import (
9
+ ImageLogger,
10
+ YOLORichModelSummary,
11
+ YOLORichProgressBar,
12
+ )
13
+ from yolo.utils.model_utils import PostProcess
14
 
15
  all = [
16
  "create_model",
17
  "Config",
18
+ "YOLORichProgressBar",
19
  "NMSConfig",
20
+ "YOLORichModelSummary",
21
  "validate_log_directory",
22
  "draw_bboxes",
23
  "Vec2Box",
 
25
  "bbox_nms",
26
  "create_converter",
27
  "AugmentationComposer",
28
+ "ImageLogger",
29
  "create_dataloader",
30
  "FastModelLoader",
31
+ "TrainModel",
32
+ "PostProcess",
 
 
33
  ]
yolo/config/general.yaml CHANGED
@@ -7,7 +7,7 @@ out_path: runs
7
  exist_ok: True
8
 
9
  lucky_number: 10
10
- use_wandb: False
11
  use_tensorboard: False
12
 
13
  weight: True # Path to weight or True for auto, False for no pretrained weight
 
7
  exist_ok: True
8
 
9
  lucky_number: 10
10
+ use_wandb: True
11
  use_tensorboard: False
12
 
13
  weight: True # Path to weight or True for auto, False for no pretrained weight
yolo/config/task/inference.yaml CHANGED
@@ -8,4 +8,4 @@ data:
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
- # save_predict: True
 
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
+ save_predict: True
yolo/config/task/validation.yaml CHANGED
@@ -8,5 +8,5 @@ data:
8
  pin_memory: True
9
  data_augment: {}
10
  nms:
11
- min_confidence: 0.05
12
- min_iou: 0.9
 
8
  pin_memory: True
9
  data_augment: {}
10
  nms:
11
+ min_confidence: 0.0001
12
+ min_iou: 0.7
yolo/lazy.py CHANGED
@@ -2,41 +2,42 @@ import sys
2
  from pathlib import Path
3
 
4
  import hydra
 
5
 
6
  project_root = Path(__file__).resolve().parent.parent
7
  sys.path.append(str(project_root))
8
 
9
  from yolo.config.config import Config
10
- from yolo.model.yolo import create_model
11
- from yolo.tools.data_loader import create_dataloader
12
- from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
13
- from yolo.utils.bounding_box_utils import create_converter
14
- from yolo.utils.deploy_utils import FastModelLoader
15
- from yolo.utils.logging_utils import ProgressLogger
16
- from yolo.utils.model_utils import get_device
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
- progress = ProgressLogger(cfg, exp_name=cfg.name)
22
- device, use_ddp = get_device(cfg.device)
23
- dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
24
- if getattr(cfg.task, "fast_inference", False):
25
- model = FastModelLoader(cfg).load_model(device)
26
- else:
27
- model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
28
- model = model.to(device)
29
-
30
- converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
 
 
 
 
31
 
32
  if cfg.task.task == "train":
33
- solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
 
34
  if cfg.task.task == "validation":
35
- solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
 
36
  if cfg.task.task == "inference":
37
- solver = ModelTester(cfg, model, converter, progress, device)
38
- progress.start()
39
- solver.solve(dataloader)
40
 
41
 
42
  if __name__ == "__main__":
 
2
  from pathlib import Path
3
 
4
  import hydra
5
+ from lightning import Trainer
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
12
+ from yolo.utils.logging_utils import setup
 
 
 
 
 
13
 
14
 
15
  @hydra.main(config_path="config", config_name="config", version_base=None)
16
  def main(cfg: Config):
17
+ callbacks, loggers, save_path = setup(cfg)
18
+
19
+ trainer = Trainer(
20
+ accelerator="auto",
21
+ max_epochs=getattr(cfg.task, "epoch", None),
22
+ precision="16-mixed",
23
+ callbacks=callbacks,
24
+ logger=loggers,
25
+ log_every_n_steps=1,
26
+ gradient_clip_val=10,
27
+ deterministic=True,
28
+ enable_progress_bar=not getattr(cfg, "quite", False),
29
+ default_root_dir=save_path,
30
+ )
31
 
32
  if cfg.task.task == "train":
33
+ model = TrainModel(cfg)
34
+ trainer.fit(model)
35
  if cfg.task.task == "validation":
36
+ model = ValidateModel(cfg)
37
+ trainer.validate(model)
38
  if cfg.task.task == "inference":
39
+ model = InferenceModel(cfg)
40
+ trainer.predict(model)
 
41
 
42
 
43
  if __name__ == "__main__":
yolo/model/module.py CHANGED
@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional, Tuple
3
  import torch
4
  import torch.nn.functional as F
5
  from einops import rearrange
6
- from loguru import logger
7
  from torch import Tensor, nn
8
  from torch.nn.common_types import _size_2_t
9
 
 
10
  from yolo.utils.module_utils import auto_pad, create_activation_function, round_up
11
 
12
 
 
3
  import torch
4
  import torch.nn.functional as F
5
  from einops import rearrange
 
6
  from torch import Tensor, nn
7
  from torch.nn.common_types import _size_2_t
8
 
9
+ from yolo.utils.logger import logger
10
  from yolo.utils.module_utils import auto_pad, create_activation_function, round_up
11
 
12
 
yolo/model/yolo.py CHANGED
@@ -3,12 +3,12 @@ from pathlib import Path
3
  from typing import Dict, List, Union
4
 
5
  import torch
6
- from loguru import logger
7
  from omegaconf import ListConfig, OmegaConf
8
  from torch import nn
9
 
10
  from yolo.config.config import ModelConfig, YOLOLayer
11
  from yolo.tools.dataset_preparation import prepare_weight
 
12
  from yolo.utils.module_utils import get_layer_map
13
 
14
 
@@ -32,10 +32,10 @@ class YOLO(nn.Module):
32
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
33
  self.layer_index = {}
34
  output_dim, layer_idx = [3], 1
35
- logger.info(f"🚜 Building YOLO")
36
  for arch_name in model_arch:
37
  if model_arch[arch_name]:
38
- logger.info(f" πŸ—οΈ Building {arch_name}")
39
  for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
40
  layer_type, layer_info = next(iter(layer_spec.items()))
41
  layer_args = layer_info.get("args", {})
@@ -126,7 +126,7 @@ class YOLO(nn.Module):
126
  weights: A OrderedDict containing the new weights.
127
  """
128
  if isinstance(weights, Path):
129
- weights = torch.load(weights, map_location=torch.device("cpu"))
130
  if "model_state_dict" in weights:
131
  weights = weights["model_state_dict"]
132
 
@@ -147,7 +147,7 @@ class YOLO(nn.Module):
147
 
148
  for error_name, error_set in error_dict.items():
149
  for weight_name in error_set:
150
- logger.warning(f"⚠️ Weight {error_name} for key: {'.'.join(weight_name)}")
151
 
152
  self.model.load_state_dict(model_state_dict)
153
 
@@ -174,7 +174,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
174
  prepare_weight(weight_path=weight_path)
175
  if weight_path.exists():
176
  model.save_load_weights(weight_path)
177
- logger.info("βœ… Success load model & weight")
178
  else:
179
- logger.info("βœ… Success load model")
180
  return model
 
3
  from typing import Dict, List, Union
4
 
5
  import torch
 
6
  from omegaconf import ListConfig, OmegaConf
7
  from torch import nn
8
 
9
  from yolo.config.config import ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
11
+ from yolo.utils.logger import logger
12
  from yolo.utils.module_utils import get_layer_map
13
 
14
 
 
32
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
33
  self.layer_index = {}
34
  output_dim, layer_idx = [3], 1
35
+ logger.info(f":tractor: Building YOLO")
36
  for arch_name in model_arch:
37
  if model_arch[arch_name]:
38
+ logger.info(f" :building_construction: Building {arch_name}")
39
  for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
40
  layer_type, layer_info = next(iter(layer_spec.items()))
41
  layer_args = layer_info.get("args", {})
 
126
  weights: A OrderedDict containing the new weights.
127
  """
128
  if isinstance(weights, Path):
129
+ weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
130
  if "model_state_dict" in weights:
131
  weights = weights["model_state_dict"]
132
 
 
147
 
148
  for error_name, error_set in error_dict.items():
149
  for weight_name in error_set:
150
+ logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")
151
 
152
  self.model.load_state_dict(model_state_dict)
153
 
 
174
  prepare_weight(weight_path=weight_path)
175
  if weight_path.exists():
176
  model.save_load_weights(weight_path)
177
+ logger.info(":white_check_mark: Success load model & weight")
178
  else:
179
+ logger.info(":white_check_mark: Success load model")
180
  return model
yolo/tools/data_loader.py CHANGED
@@ -5,12 +5,10 @@ from typing import Generator, List, Tuple, Union
5
 
6
  import numpy as np
7
  import torch
8
- from loguru import logger
9
  from PIL import Image
10
  from rich.progress import track
11
  from torch import Tensor
12
  from torch.utils.data import DataLoader, Dataset
13
- from torch.utils.data.distributed import DistributedSampler
14
 
15
  from yolo.config.config import DataConfig, DatasetConfig
16
  from yolo.tools.data_augmentation import *
@@ -20,7 +18,9 @@ from yolo.utils.dataset_utils import (
20
  create_image_metadata,
21
  locate_label_paths,
22
  scale_segmentation,
 
23
  )
 
24
 
25
 
26
  class YoloDataset(Dataset):
@@ -32,7 +32,8 @@ class YoloDataset(Dataset):
32
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
33
  self.transform = AugmentationComposer(transforms, self.image_size)
34
  self.transform.get_more_data = self.get_more_data
35
- self.data = self.load_data(Path(dataset_cfg.path), phase_name)
 
36
 
37
  def load_data(self, dataset_path: Path, phase_name: str):
38
  """
@@ -48,12 +49,12 @@ class YoloDataset(Dataset):
48
  cache_path = dataset_path / f"{phase_name}.cache"
49
 
50
  if not cache_path.exists():
51
- logger.info("🏭 Generating {} cache", phase_name)
52
  data = self.filter_data(dataset_path, phase_name)
53
  torch.save(data, cache_path)
54
  else:
55
  data = torch.load(cache_path, weights_only=False)
56
- logger.info("πŸ“¦ Loaded {} cache", phase_name)
57
  return data
58
 
59
  def filter_data(self, dataset_path: Path, phase_name: str) -> list:
@@ -103,7 +104,7 @@ class YoloDataset(Dataset):
103
  img_path = images_path / image_name
104
  data.append((img_path, labels))
105
  valid_inputs += 1
106
- logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
107
  return data
108
 
109
  def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
@@ -132,9 +133,11 @@ class YoloDataset(Dataset):
132
  return torch.zeros((0, 5))
133
 
134
  def get_data(self, idx):
135
- img_path, bboxes = self.data[idx]
136
- img = Image.open(img_path).convert("RGB")
137
- return img, bboxes, img_path
 
 
138
 
139
  def get_more_data(self, num: int = 1):
140
  indices = torch.randint(0, len(self), (num,))
@@ -143,67 +146,59 @@ class YoloDataset(Dataset):
143
  def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
144
  img, bboxes, img_path = self.get_data(idx)
145
  img, bboxes, rev_tensor = self.transform(img, bboxes)
 
 
146
  return img, bboxes, rev_tensor, img_path
147
 
148
  def __len__(self) -> int:
149
- return len(self.data)
150
-
151
-
152
- class YoloDataLoader(DataLoader):
153
- def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
154
- """Initializes the YoloDataLoader with hydra-config files."""
155
- dataset = YoloDataset(data_cfg, dataset_cfg, task)
156
- sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
157
- self.image_size = data_cfg.image_size[0]
158
- super().__init__(
159
- dataset,
160
- batch_size=data_cfg.batch_size,
161
- sampler=sampler,
162
- shuffle=data_cfg.shuffle and not use_ddp,
163
- num_workers=data_cfg.cpu_num,
164
- pin_memory=data_cfg.pin_memory,
165
- collate_fn=self.collate_fn,
166
- )
167
-
168
- def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
169
- """
170
- A collate function to handle batching of images and their corresponding targets.
171
 
172
- Args:
173
- batch (list of tuples): Each tuple contains:
174
- - image (Tensor): The image tensor.
175
- - labels (Tensor): The tensor of labels for the image.
176
 
177
- Returns:
178
- Tuple[Tensor, List[Tensor]]: A tuple containing:
179
- - A tensor of batched images.
180
- - A list of tensors, each corresponding to bboxes for each image in the batch.
181
- """
182
- batch_size = len(batch)
183
- target_sizes = [item[1].size(0) for item in batch]
184
- # TODO: Improve readability of these proccess
185
- # TODO: remove maxBbox or reduce loss function memory usage
186
- batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
187
- batch_targets[:, :, 0] = -1
188
- for idx, target_size in enumerate(target_sizes):
189
- batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
190
- batch_targets[:, :, 1:] *= self.image_size
191
 
192
- batch_images, _, batch_reverse, batch_path = zip(*batch)
193
- batch_images = torch.stack(batch_images)
194
- batch_reverse = torch.stack(batch_reverse)
 
195
 
196
- return batch_size, batch_images, batch_targets, batch_reverse, batch_path
 
 
 
 
 
 
 
 
 
 
 
 
197
 
 
 
 
198
 
199
- def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
 
 
 
200
  if task == "inference":
201
  return StreamDataLoader(data_cfg)
202
 
203
  if dataset_cfg.auto_download:
204
  prepare_dataset(dataset_cfg, task)
205
-
206
- return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp)
 
 
 
 
 
 
 
207
 
208
 
209
  class StreamDataLoader:
 
5
 
6
  import numpy as np
7
  import torch
 
8
  from PIL import Image
9
  from rich.progress import track
10
  from torch import Tensor
11
  from torch.utils.data import DataLoader, Dataset
 
12
 
13
  from yolo.config.config import DataConfig, DatasetConfig
14
  from yolo.tools.data_augmentation import *
 
18
  create_image_metadata,
19
  locate_label_paths,
20
  scale_segmentation,
21
+ tensorlize,
22
  )
23
+ from yolo.utils.logger import logger
24
 
25
 
26
  class YoloDataset(Dataset):
 
32
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
33
  self.transform = AugmentationComposer(transforms, self.image_size)
34
  self.transform.get_more_data = self.get_more_data
35
+ img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
36
+ self.img_paths, self.bboxes = img_paths, bboxes
37
 
38
  def load_data(self, dataset_path: Path, phase_name: str):
39
  """
 
49
  cache_path = dataset_path / f"{phase_name}.cache"
50
 
51
  if not cache_path.exists():
52
+ logger.info(f":factory: Generating {phase_name} cache")
53
  data = self.filter_data(dataset_path, phase_name)
54
  torch.save(data, cache_path)
55
  else:
56
  data = torch.load(cache_path, weights_only=False)
57
+ logger.info(f":package: Loaded {phase_name} cache")
58
  return data
59
 
60
  def filter_data(self, dataset_path: Path, phase_name: str) -> list:
 
104
  img_path = images_path / image_name
105
  data.append((img_path, labels))
106
  valid_inputs += 1
107
+ logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
108
  return data
109
 
110
  def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
 
133
  return torch.zeros((0, 5))
134
 
135
  def get_data(self, idx):
136
+ img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
137
+ valid_mask = bboxes[:, 0] != -1
138
+ with Image.open(img_path) as img:
139
+ img = img.convert("RGB")
140
+ return img, torch.from_numpy(bboxes[valid_mask]), img_path
141
 
142
  def get_more_data(self, num: int = 1):
143
  indices = torch.randint(0, len(self), (num,))
 
146
  def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
147
  img, bboxes, img_path = self.get_data(idx)
148
  img, bboxes, rev_tensor = self.transform(img, bboxes)
149
+ bboxes[:, [1, 3]] *= self.image_size[0]
150
+ bboxes[:, [2, 4]] *= self.image_size[1]
151
  return img, bboxes, rev_tensor, img_path
152
 
153
  def __len__(self) -> int:
154
+ return len(self.bboxes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
 
 
 
 
156
 
157
+ def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
158
+ """
159
+ A collate function to handle batching of images and their corresponding targets.
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ Args:
162
+ batch (list of tuples): Each tuple contains:
163
+ - image (Tensor): The image tensor.
164
+ - labels (Tensor): The tensor of labels for the image.
165
 
166
+ Returns:
167
+ Tuple[Tensor, List[Tensor]]: A tuple containing:
168
+ - A tensor of batched images.
169
+ - A list of tensors, each corresponding to bboxes for each image in the batch.
170
+ """
171
+ batch_size = len(batch)
172
+ target_sizes = [item[1].size(0) for item in batch]
173
+ # TODO: Improve readability of these process
174
+ # TODO: remove maxBbox or reduce loss function memory usage
175
+ batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
176
+ batch_targets[:, :, 0] = -1
177
+ for idx, target_size in enumerate(target_sizes):
178
+ batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
179
 
180
+ batch_images, _, batch_reverse, batch_path = zip(*batch)
181
+ batch_images = torch.stack(batch_images)
182
+ batch_reverse = torch.stack(batch_reverse)
183
 
184
+ return batch_size, batch_images, batch_targets, batch_reverse, batch_path
185
+
186
+
187
+ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
188
  if task == "inference":
189
  return StreamDataLoader(data_cfg)
190
 
191
  if dataset_cfg.auto_download:
192
  prepare_dataset(dataset_cfg, task)
193
+ dataset = YoloDataset(data_cfg, dataset_cfg, task)
194
+
195
+ return DataLoader(
196
+ dataset,
197
+ batch_size=data_cfg.batch_size,
198
+ num_workers=data_cfg.cpu_num,
199
+ pin_memory=data_cfg.pin_memory,
200
+ collate_fn=collate_fn,
201
+ )
202
 
203
 
204
  class StreamDataLoader:
yolo/tools/dataset_preparation.py CHANGED
@@ -3,10 +3,10 @@ from pathlib import Path
3
  from typing import Optional
4
 
5
  import requests
6
- from loguru import logger
7
  from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
8
 
9
  from yolo.config.config import DatasetConfig
 
10
 
11
 
12
  def download_file(url, destination: Path):
@@ -30,7 +30,7 @@ def download_file(url, destination: Path):
30
  for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
31
  file.write(data)
32
  progress.update(task, advance=len(data))
33
- logger.info("βœ… Download completed.")
34
 
35
 
36
  def unzip_file(source: Path, destination: Path):
@@ -71,7 +71,7 @@ def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
71
 
72
  final_place.mkdir(parents=True, exist_ok=True)
73
  if check_files(final_place, dataset_args.get("file_num")):
74
- logger.info(f"βœ… Dataset {dataset_type: <12} already verified.")
75
  continue
76
 
77
  if not local_zip_path.exists():
 
3
  from typing import Optional
4
 
5
  import requests
 
6
  from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
7
 
8
  from yolo.config.config import DatasetConfig
9
+ from yolo.utils.logger import logger
10
 
11
 
12
  def download_file(url, destination: Path):
 
30
  for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
31
  file.write(data)
32
  progress.update(task, advance=len(data))
33
+ logger.info(":white_check_mark: Download completed.")
34
 
35
 
36
  def unzip_file(source: Path, destination: Path):
 
71
 
72
  final_place.mkdir(parents=True, exist_ok=True)
73
  if check_files(final_place, dataset_args.get("file_num")):
74
+ logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.")
75
  continue
76
 
77
  if not local_zip_path.exists():
yolo/tools/drawer.py CHANGED
@@ -3,12 +3,12 @@ from typing import List, Optional, Union
3
 
4
  import numpy as np
5
  import torch
6
- from loguru import logger
7
  from PIL import Image, ImageDraw, ImageFont
8
  from torchvision.transforms.functional import to_pil_image
9
 
10
  from yolo.config.config import ModelConfig
11
  from yolo.model.yolo import YOLO
 
12
 
13
 
14
  def draw_bboxes(
@@ -121,6 +121,6 @@ def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=Fal
121
  dot.edge(str(idx), str(jdx))
122
  try:
123
  dot.render("Model-arch", format="png", cleanup=True)
124
- logger.info("🎨 Drawing Model Architecture at Model-arch.png")
125
  except:
126
- logger.warning("⚠️ Could not find graphviz backend, continue without drawing the model architecture")
 
3
 
4
  import numpy as np
5
  import torch
 
6
  from PIL import Image, ImageDraw, ImageFont
7
  from torchvision.transforms.functional import to_pil_image
8
 
9
  from yolo.config.config import ModelConfig
10
  from yolo.model.yolo import YOLO
11
+ from yolo.utils.logger import logger
12
 
13
 
14
  def draw_bboxes(
 
121
  dot.edge(str(idx), str(jdx))
122
  try:
123
  dot.render("Model-arch", format="png", cleanup=True)
124
+ logger.info(":artist_palette: Drawing Model Architecture at Model-arch.png")
125
  except:
126
+ logger.warning(":warning: Could not find graphviz backend, continue without drawing the model architecture")
yolo/tools/loss_functions.py CHANGED
@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Tuple
2
 
3
  import torch
4
  import torch.nn.functional as F
5
- from loguru import logger
6
  from torch import Tensor, nn
7
  from torch.nn import BCEWithLogitsLoss
8
 
9
  from yolo.config.config import Config, LossConfig
10
  from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
 
11
 
12
 
13
  class BCELoss(nn.Module):
@@ -119,22 +119,24 @@ class DualLoss:
119
 
120
  def __call__(
121
  self, aux_predicts: List[Tensor], main_predicts: List[Tensor], targets: Tensor
122
- ) -> Tuple[Tensor, Dict[str, Tensor]]:
123
  # TODO: Need Refactor this region, make it flexible!
124
  aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
125
  main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
126
 
 
 
 
 
 
127
  loss_dict = {
128
- "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
129
- "DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
130
- "BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
131
  }
132
- loss_sum = sum(list(loss_dict.values())) / len(loss_dict)
133
- return loss_sum, loss_dict
134
 
135
 
136
  def create_loss_function(cfg: Config, vec2box) -> DualLoss:
137
  # TODO: make it flexible, if cfg doesn't contain aux, only use SingleLoss
138
  loss_function = DualLoss(cfg, vec2box)
139
- logger.info("βœ… Success load loss function")
140
  return loss_function
 
2
 
3
  import torch
4
  import torch.nn.functional as F
 
5
  from torch import Tensor, nn
6
  from torch.nn import BCEWithLogitsLoss
7
 
8
  from yolo.config.config import Config, LossConfig
9
  from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
10
+ from yolo.utils.logger import logger
11
 
12
 
13
  class BCELoss(nn.Module):
 
119
 
120
  def __call__(
121
  self, aux_predicts: List[Tensor], main_predicts: List[Tensor], targets: Tensor
122
+ ) -> Tuple[Tensor, Dict[str, float]]:
123
  # TODO: Need Refactor this region, make it flexible!
124
  aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
125
  main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
126
 
127
+ total_loss = [
128
+ self.iou_rate * (aux_iou * self.aux_rate + main_iou),
129
+ self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
130
+ self.cls_rate * (aux_cls * self.aux_rate + main_cls),
131
+ ]
132
  loss_dict = {
133
+ f"Loss/{name}Loss": value.detach().item() for name, value in zip(["Box", "DFL", "BCE"], total_loss)
 
 
134
  }
135
+ return sum(total_loss), loss_dict
 
136
 
137
 
138
  def create_loss_function(cfg: Config, vec2box) -> DualLoss:
139
  # TODO: make it flexible, if cfg doesn't contain aux, only use SingleLoss
140
  loss_function = DualLoss(cfg, vec2box)
141
+ logger.info(":white_check_mark: Success load loss function")
142
  return loss_function
yolo/tools/solver.py CHANGED
@@ -1,267 +1,142 @@
1
- import contextlib
2
- import io
3
- import json
4
- import os
5
- import time
6
- from collections import defaultdict
7
  from pathlib import Path
8
- from typing import Dict, Optional
9
 
10
- import torch
11
- from loguru import logger
12
- from pycocotools.coco import COCO
13
- from torch import Tensor, distributed
14
- from torch.cuda.amp import GradScaler, autocast
15
- from torch.nn.parallel import DistributedDataParallel as DDP
16
- from torch.utils.data import DataLoader
17
 
18
- from yolo.config.config import Config, DatasetConfig, TrainConfig, ValidationConfig
19
- from yolo.model.yolo import YOLO
20
- from yolo.tools.data_loader import StreamDataLoader, create_dataloader
21
- from yolo.tools.drawer import draw_bboxes, draw_model
22
  from yolo.tools.loss_functions import create_loss_function
23
- from yolo.utils.bounding_box_utils import Vec2Box, calculate_map
24
- from yolo.utils.dataset_utils import locate_label_paths
25
- from yolo.utils.logging_utils import ProgressLogger, log_model_structure
26
- from yolo.utils.model_utils import (
27
- ExponentialMovingAverage,
28
- PostProccess,
29
- collect_prediction,
30
- create_optimizer,
31
- create_scheduler,
32
- predicts_to_json,
33
- )
34
- from yolo.utils.solver_utils import calculate_ap
35
 
36
 
37
- class ModelTrainer:
38
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
39
- train_cfg: TrainConfig = cfg.task
40
- self.model = model if not use_ddp else DDP(model, device_ids=[device])
41
- self.use_ddp = use_ddp
42
- self.vec2box = vec2box
43
- self.device = device
44
- self.optimizer = create_optimizer(model, train_cfg.optimizer)
45
- self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
46
- self.loss_fn = create_loss_function(cfg, vec2box)
47
- self.progress = progress
48
- self.num_epochs = cfg.task.epoch
49
- self.mAPs_dict = defaultdict(list)
50
 
51
- self.weights_dir = self.progress.save_path / "weights"
52
- self.weights_dir.mkdir(exist_ok=True)
53
 
54
- if not progress.quite_mode:
55
- log_model_structure(model.model)
56
- draw_model(model=model)
57
 
58
- self.validation_dataloader = create_dataloader(
59
- cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
60
- )
61
- self.validator = ModelValidator(cfg.task.validation, cfg.dataset, model, vec2box, progress, device)
62
-
63
- if getattr(train_cfg.ema, "enabled", False):
64
- self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
65
  else:
66
- self.ema = None
67
- self.scaler = GradScaler()
68
-
69
- def train_one_batch(self, images: Tensor, targets: Tensor):
70
- images, targets = images.to(self.device), targets.to(self.device)
71
- self.optimizer.zero_grad()
72
-
73
- with autocast():
74
- predicts = self.model(images)
75
- aux_predicts = self.vec2box(predicts["AUX"])
76
- main_predicts = self.vec2box(predicts["Main"])
77
- loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
78
-
79
- self.scaler.scale(loss).backward()
80
- self.scaler.unscale_(self.optimizer)
81
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
82
- self.scaler.step(self.optimizer)
83
- self.scaler.update()
84
-
85
- return loss_item
86
-
87
- def train_one_epoch(self, dataloader):
88
- self.model.train()
89
- total_loss = defaultdict(lambda: torch.tensor(0.0, device=self.device))
90
- total_samples = 0
91
- self.optimizer.next_epoch(len(dataloader))
92
- for batch_size, images, targets, *_ in dataloader:
93
- self.optimizer.next_batch()
94
- loss_each = self.train_one_batch(images, targets)
95
-
96
- for loss_name, loss_val in loss_each.items():
97
- if self.use_ddp: # collecting loss for each batch
98
- distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG)
99
- total_loss[loss_name] += loss_val * batch_size
100
- total_samples += batch_size
101
- self.progress.one_batch(loss_each)
102
-
103
- for loss_val in total_loss.values():
104
- loss_val /= total_samples
105
-
106
- if self.scheduler:
107
- self.scheduler.step()
108
-
109
- return total_loss
110
-
111
- def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
112
- file_name = file_name or f"E{epoch_idx:03d}.pt"
113
- file_path = self.weights_dir / file_name
114
-
115
- checkpoint = {
116
- "epoch": epoch_idx,
117
- "model_state_dict": self.model.state_dict(),
118
- "optimizer_state_dict": self.optimizer.state_dict(),
119
- }
120
- if self.ema:
121
- self.ema.apply_shadow()
122
- checkpoint["model_state_dict_ema"] = self.model.state_dict()
123
- self.ema.restore()
124
-
125
- logger.info(f"πŸ’Ύ success save at {file_path}")
126
- torch.save(checkpoint, file_path)
127
-
128
- def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
129
- save_flag = True
130
- for mAP_key, mAP_val in mAPs.items():
131
- self.mAPs_dict[mAP_key].append(mAP_val)
132
- if mAP_val < max(self.mAPs_dict[mAP_key]):
133
- save_flag = False
134
- return save_flag
135
-
136
- def solve(self, dataloader: DataLoader):
137
- logger.info("πŸš„ Start Training!")
138
- num_epochs = self.num_epochs
139
-
140
- self.progress.start_train(num_epochs)
141
- for epoch_idx in range(num_epochs):
142
- if self.use_ddp:
143
- dataloader.sampler.set_epoch(epoch_idx)
144
-
145
- self.progress.start_one_epoch(len(dataloader), "Train", self.optimizer, epoch_idx)
146
- epoch_loss = self.train_one_epoch(dataloader)
147
- self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
148
-
149
- mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
150
- if mAPs is not None and self.good_epoch(mAPs):
151
- self.save_checkpoint(epoch_idx=epoch_idx)
152
- # TODO: save model if result are better than before
153
- self.progress.finish_train()
154
-
155
-
156
- class ModelTester:
157
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
158
- self.model = model
159
- self.device = device
160
- self.progress = progress
161
-
162
- self.post_proccess = PostProccess(vec2box, cfg.task.nms)
163
- self.save_path = progress.save_path / "images"
164
- os.makedirs(self.save_path, exist_ok=True)
165
- self.save_predict = getattr(cfg.task, "save_predict", None)
166
- self.idx2label = cfg.dataset.class_list
167
-
168
- def solve(self, dataloader: StreamDataLoader):
169
- logger.info("πŸ‘€ Start Inference!")
170
- if isinstance(self.model, torch.nn.Module):
171
- self.model.eval()
172
-
173
- if dataloader.is_stream:
174
- import cv2
175
- import numpy as np
176
-
177
- last_time = time.time()
178
- try:
179
- for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
180
- images = images.to(self.device)
181
- rev_tensor = rev_tensor.to(self.device)
182
- with torch.no_grad():
183
- predicts = self.model(images)
184
- predicts = self.post_proccess(predicts, rev_tensor)
185
- img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
186
-
187
- if dataloader.is_stream:
188
- img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
189
- fps = 1 / (time.time() - last_time)
190
- cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
191
- last_time = time.time()
192
- cv2.imshow("Prediction", img)
193
- if cv2.waitKey(1) & 0xFF == ord("q"):
194
- break
195
- if not self.save_predict:
196
- continue
197
- if self.save_predict != False:
198
- save_image_path = self.save_path / f"frame{idx:03d}.png"
199
- img.save(save_image_path)
200
- logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
201
-
202
- except (KeyboardInterrupt, Exception) as e:
203
- dataloader.stop_event.set()
204
- dataloader.stop()
205
- if isinstance(e, KeyboardInterrupt):
206
- logger.error("User Keyboard Interrupt")
207
- else:
208
- raise e
209
- dataloader.stop()
210
 
 
 
211
 
212
- class ModelValidator:
213
- def __init__(
214
- self,
215
- validation_cfg: ValidationConfig,
216
- dataset_cfg: DatasetConfig,
217
- model: YOLO,
218
- vec2box: Vec2Box,
219
- progress: ProgressLogger,
220
- device,
221
- ):
222
- self.model = model
223
- self.device = device
224
- self.progress = progress
225
 
226
- self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
227
- self.json_path = self.progress.save_path / "predict.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
- with contextlib.redirect_stdout(io.StringIO()):
230
- # TODO: load with config file
231
- json_path, _ = locate_label_paths(Path(dataset_cfg.path), dataset_cfg.get("validation", "val"))
232
- if json_path:
233
- self.coco_gt = COCO(json_path)
234
 
235
- def solve(self, dataloader, epoch_idx=1):
236
- # logger.info("πŸ§ͺ Start Validation!")
237
- self.model.eval()
238
- predict_json, mAPs = [], defaultdict(list)
239
- self.progress.start_one_epoch(len(dataloader), task="Validate")
240
- for batch_size, images, targets, rev_tensor, img_paths in dataloader:
241
- images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
242
- with torch.no_grad():
243
- predicts = self.model(images)
244
- predicts = self.post_proccess(predicts)
245
- for idx, predict in enumerate(predicts):
246
- mAP = calculate_map(predict, targets[idx])
247
- for mAP_key, mAP_val in mAP.items():
248
- mAPs[mAP_key].append(mAP_val)
249
 
250
- avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
251
- self.progress.one_batch(avg_mAPs)
 
 
 
 
252
 
253
- predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
254
- self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
255
- self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
 
 
256
 
257
- with open(self.json_path, "w") as f:
258
- predict_json = collect_prediction(predict_json, self.progress.local_rank)
259
- if self.progress.local_rank != 0:
260
- return
261
- json.dump(predict_json, f)
262
- if hasattr(self, "coco_gt"):
263
- self.progress.start_pycocotools()
264
- result = calculate_ap(self.coco_gt, predict_json)
265
- self.progress.finish_pycocotools(result, epoch_idx)
266
 
267
- return avg_mAPs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
 
2
 
3
+ from lightning import LightningModule
4
+ from torchmetrics.detection import MeanAveragePrecision
 
 
 
 
 
5
 
6
+ from yolo.config.config import Config
7
+ from yolo.model.yolo import create_model
8
+ from yolo.tools.data_loader import create_dataloader
9
+ from yolo.tools.drawer import draw_bboxes
10
  from yolo.tools.loss_functions import create_loss_function
11
+ from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
12
+ from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler
 
 
 
 
 
 
 
 
 
 
13
 
14
 
15
+ class BaseModel(LightningModule):
16
+ def __init__(self, cfg: Config):
17
+ super().__init__()
18
+ self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
 
 
 
 
 
 
 
 
 
19
 
20
+ def forward(self, x):
21
+ return self.model(x)
22
 
 
 
 
23
 
24
+ class ValidateModel(BaseModel):
25
+ def __init__(self, cfg: Config):
26
+ super().__init__(cfg)
27
+ self.cfg = cfg
28
+ if self.cfg.task.task == "validation":
29
+ self.validation_cfg = self.cfg.task
 
30
  else:
31
+ self.validation_cfg = self.cfg.task.validation
32
+ self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
33
+ self.metric.warn_on_many_detections = False
34
+ self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
35
+
36
+ def setup(self, stage):
37
+ self.vec2box = create_converter(
38
+ self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
39
+ )
40
+ self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ def val_dataloader(self):
43
+ return self.val_loader
44
 
45
+ def validation_step(self, batch, batch_idx):
46
+ batch_size, images, targets, rev_tensor, img_paths = batch
47
+ predicts = self.post_process(self(images))
48
+ batch_metrics = self.metric(
49
+ [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
50
+ )
 
 
 
 
 
 
 
51
 
52
+ self.log_dict(
53
+ {
54
+ "map": batch_metrics["map"],
55
+ "map_50": batch_metrics["map_50"],
56
+ },
57
+ on_step=True,
58
+ batch_size=batch_size,
59
+ )
60
+ return predicts
61
+
62
+ def on_validation_epoch_end(self):
63
+ epoch_metrics = self.metric.compute()
64
+ del epoch_metrics["classes"]
65
+ self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
66
+ self.log_dict(
67
+ {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, rank_zero_only=True
68
+ )
69
+ self.metric.reset()
70
+
71
+
72
+ class TrainModel(ValidateModel):
73
+ def __init__(self, cfg: Config):
74
+ super().__init__(cfg)
75
+ self.cfg = cfg
76
+ self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
77
+
78
+ def setup(self, stage):
79
+ super().setup(stage)
80
+ self.loss_fn = create_loss_function(self.cfg, self.vec2box)
81
+
82
+ def train_dataloader(self):
83
+ return self.train_loader
84
+
85
+ def on_train_epoch_start(self):
86
+ self.trainer.optimizers[0].next_epoch(len(self.train_loader))
87
+
88
+ def training_step(self, batch, batch_idx):
89
+ lr_dict = self.trainer.optimizers[0].next_batch()
90
+ batch_size, images, targets, *_ = batch
91
+ predicts = self(images)
92
+ aux_predicts = self.vec2box(predicts["AUX"])
93
+ main_predicts = self.vec2box(predicts["Main"])
94
+ loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
95
+ self.log_dict(
96
+ loss_item,
97
+ prog_bar=True,
98
+ on_epoch=True,
99
+ batch_size=batch_size,
100
+ rank_zero_only=True,
101
+ )
102
+ self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
103
+ return loss * batch_size
104
 
105
+ def configure_optimizers(self):
106
+ optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
107
+ scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
108
+ return [optimizer], [scheduler]
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ class InferenceModel(BaseModel):
112
+ def __init__(self, cfg: Config):
113
+ super().__init__(cfg)
114
+ self.cfg = cfg
115
+ # TODO: Add FastModel
116
+ self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
117
 
118
+ def setup(self, stage):
119
+ self.vec2box = create_converter(
120
+ self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
121
+ )
122
+ self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)
123
 
124
+ def predict_dataloader(self):
125
+ return self.predict_loader
 
 
 
 
 
 
 
126
 
127
+ def predict_step(self, batch, batch_idx):
128
+ images, rev_tensor, origin_frame = batch
129
+ predicts = self.post_process(self(images), rev_tensor)
130
+ img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
131
+ if getattr(self.predict_loader, "is_stream", None):
132
+ fps = self._display_stream(img)
133
+ else:
134
+ fps = None
135
+ if getattr(self.cfg.task, "save_predict", None):
136
+ self._save_image(img, batch_idx)
137
+ return img, fps
138
+
139
+ def _save_image(self, img, batch_idx):
140
+ save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png"
141
+ img.save(save_image_path)
142
+ print(f"πŸ’Ύ Saved visualize image at {save_image_path}")
yolo/utils/bounding_box_utils.py CHANGED
@@ -4,17 +4,18 @@ from typing import Dict, List, Optional, Tuple, Union
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
- from loguru import logger
8
- from torch import Tensor, arange, tensor
9
  from torchvision.ops import batched_nms
10
 
11
- from yolo.config.config import AnchorConfig, MatcherConfig, ModelConfig, NMSConfig
12
  from yolo.model.yolo import YOLO
 
13
 
14
 
15
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
16
  metrics = metrics.lower()
17
- EPS = 1e-9
18
  dtype = bbox1.dtype
19
  bbox1 = bbox1.to(torch.float32)
20
  bbox2 = bbox2.to(torch.float32)
@@ -69,7 +70,8 @@ def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
69
  (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
70
  )
71
  v = (4 / (math.pi**2)) * (arctan**2)
72
- alpha = v / (v - iou + 1 + EPS)
 
73
  # Compute CIoU
74
  ciou = diou - alpha * v
75
  return ciou.to(dtype)
@@ -129,7 +131,10 @@ def generate_anchors(image_size: List[int], strides: List[int]):
129
  shift = stride // 2
130
  h = torch.arange(0, H, stride) + shift
131
  w = torch.arange(0, W, stride) + shift
132
- anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
 
 
 
133
  anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
134
  anchors.append(anchor)
135
  all_anchors = torch.cat(anchors, dim=0)
@@ -207,7 +212,7 @@ class BoxMatcher:
207
  topk_masks = topk_targets > 0
208
  return topk_targets, topk_masks
209
 
210
- def filter_duplicates(self, target_matrix: Tensor):
211
  """
212
  Filter the maximum suitability target index of each anchor.
213
 
@@ -217,17 +222,44 @@ class BoxMatcher:
217
  Returns:
218
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
219
  """
220
- # TODO: add a assert for no target on the image
221
- unique_indices = target_matrix.argmax(dim=1)
222
- return unique_indices[..., None]
 
 
223
 
224
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
225
- """
226
- 1. For each anchor prediction, find the highest suitability targets
227
- 2. Select the targets
228
- 2. Noramlize the class probilities of targets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  """
230
  predict_cls, predict_bbox = predict
 
 
 
 
 
 
 
 
 
 
 
231
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
232
  target_cls = target_cls.long().clamp(0)
233
 
@@ -246,23 +278,22 @@ class BoxMatcher:
246
  topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
247
 
248
  # delete one anchor pred assign to mutliple gts
249
- unique_indices = self.filter_duplicates(topk_targets)
250
-
251
- # TODO: do we need grid_mask? Filter the valid groud truth
252
- valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
253
 
254
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
255
  align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
256
  align_cls = F.one_hot(align_cls, self.class_num)
257
 
258
  # normalize class ditribution
 
 
259
  max_target = target_matrix.amax(dim=-1, keepdim=True)
260
  max_iou = iou_mat.amax(dim=-1, keepdim=True)
261
  normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
262
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
263
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
264
-
265
- return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
266
 
267
 
268
  class Vec2Box:
@@ -270,7 +301,7 @@ class Vec2Box:
270
  self.device = device
271
 
272
  if hasattr(anchor_cfg, "strides"):
273
- logger.info(f"🈢 Found stride of model {anchor_cfg.strides}")
274
  self.strides = anchor_cfg.strides
275
  else:
276
  logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
@@ -314,7 +345,7 @@ class Anc2Box:
314
  self.device = device
315
 
316
  if hasattr(anchor_cfg, "strides"):
317
- logger.info(f"🈢 Found stride of model {anchor_cfg.strides}")
318
  self.strides = anchor_cfg.strides
319
  else:
320
  logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
@@ -388,7 +419,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt
388
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
389
 
390
  batch_idx, *_ = torch.where(valid_mask)
391
- nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
392
  predicts_nms = []
393
  for idx in range(cls_dist.size(0)):
394
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
@@ -401,48 +432,14 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt
401
  return predicts_nms
402
 
403
 
404
- def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05)) -> Dict[str, Tensor]:
405
- # TODO: Refactor this block, Flexible for calculate different mAP condition?
406
- device = predictions.device
407
- n_preds = predictions.size(0)
408
- n_gts = (ground_truths[:, 0] != -1).sum()
409
- ground_truths = ground_truths[:n_gts]
410
- aps = []
411
-
412
- ious = calculate_iou(predictions[:, 1:-1], ground_truths[:, 1:]) # [n_preds, n_gts]
413
-
414
- for threshold in iou_thresholds:
415
- tp = torch.zeros(n_preds, device=device, dtype=bool)
416
-
417
- max_iou, max_indices = ious.max(dim=1)
418
- above_threshold = max_iou >= threshold
419
- matched_classes = predictions[:, 0] == ground_truths[max_indices, 0]
420
- max_match = torch.zeros_like(ious)
421
- max_match[arange(n_preds), max_indices] = max_iou
422
- if max_match.size(0):
423
- tp[max_match.argmax(dim=0)] = True
424
- tp[~above_threshold | ~matched_classes] = False
425
-
426
- _, indices = torch.sort(predictions[:, 1], descending=True)
427
- tp = tp[indices]
428
-
429
- tp_cumsum = torch.cumsum(tp, dim=0)
430
- fp_cumsum = torch.cumsum(~tp, dim=0)
431
-
432
- precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-6)
433
- recall = tp_cumsum / (n_gts + 1e-6)
434
-
435
- precision = torch.cat([torch.ones(1, device=device), precision, torch.zeros(1, device=device)])
436
- recall = torch.cat([torch.zeros(1, device=device), recall, torch.ones(1, device=device)])
437
-
438
- precision, _ = torch.cummax(precision.flip(0), dim=0)
439
- precision = precision.flip(0)
440
 
441
- ap = torch.trapezoid(precision, recall)
442
- aps.append(ap)
443
 
444
- mAP = {
445
- "mAP.5": aps[0],
446
- "mAP.5:.95": torch.mean(torch.stack(aps)),
447
- }
448
- return mAP
 
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
+ from torch import Tensor, tensor
8
+ from torchmetrics.detection import MeanAveragePrecision
9
  from torchvision.ops import batched_nms
10
 
11
+ from yolo.config.config import AnchorConfig, MatcherConfig, NMSConfig
12
  from yolo.model.yolo import YOLO
13
+ from yolo.utils.logger import logger
14
 
15
 
16
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
17
  metrics = metrics.lower()
18
+ EPS = 1e-7
19
  dtype = bbox1.dtype
20
  bbox1 = bbox1.to(torch.float32)
21
  bbox2 = bbox2.to(torch.float32)
 
70
  (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
71
  )
72
  v = (4 / (math.pi**2)) * (arctan**2)
73
+ with torch.no_grad():
74
+ alpha = v / (v - iou + 1 + EPS)
75
  # Compute CIoU
76
  ciou = diou - alpha * v
77
  return ciou.to(dtype)
 
131
  shift = stride // 2
132
  h = torch.arange(0, H, stride) + shift
133
  w = torch.arange(0, W, stride) + shift
134
+ if torch.__version__ >= "2.3.0":
135
+ anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
136
+ else:
137
+ anchor_h, anchor_w = torch.meshgrid(h, w)
138
  anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
139
  anchors.append(anchor)
140
  all_anchors = torch.cat(anchors, dim=0)
 
212
  topk_masks = topk_targets > 0
213
  return topk_targets, topk_masks
214
 
215
+ def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
216
  """
217
  Filter the maximum suitability target index of each anchor.
218
 
 
222
  Returns:
223
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
224
  """
225
+ duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
226
+ max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
227
+ topk_mask = torch.where(duplicates, max_idx, topk_mask)
228
+ unique_indices = topk_mask.argmax(dim=1)
229
+ return unique_indices[..., None], topk_mask.sum(1), topk_mask
230
 
231
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
232
+ """Matches each target to the most suitable anchor.
233
+ 1. For each anchor prediction, find the highest suitability targets.
234
+ 2. Match target to the best anchor.
235
+ 3. Noramlize the class probilities of targets.
236
+
237
+ Args:
238
+ target: The ground truth class and bounding box information
239
+ as tensor of size [batch x targets x 5].
240
+ predict: Tuple of predicted class and bounding box tensors.
241
+ Class tensor is of size [batch x anchors x class]
242
+ Bounding box tensor is of size [batch x anchors x 4].
243
+
244
+ Returns:
245
+ anchor_matched_targets: Tensor of size [batch x anchors x (class + 4)].
246
+ A tensor assigning each target/gt to the best fitting anchor.
247
+ The class probabilities are normalized.
248
+ valid_mask: Bool tensor of shape [batch x anchors].
249
+ True if a anchor has a target/gt assigned to it.
250
  """
251
  predict_cls, predict_bbox = predict
252
+
253
+ # return if target has no gt information.
254
+ n_targets = target.shape[1]
255
+ if n_targets == 0:
256
+ device = predict_bbox.device
257
+ align_cls = torch.zeros_like(predict_cls, device=device)
258
+ align_bbox = torch.zeros_like(predict_bbox, device=device)
259
+ valid_mask = torch.zeros(predict_cls.shape[:2], dtype=bool, device=device)
260
+ anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
261
+ return anchor_matched_targets, valid_mask
262
+
263
  target_cls, target_bbox = target.split([1, 4], dim=-1) # B x N x (C B) -> B x N x C, B x N x B
264
  target_cls = target_cls.long().clamp(0)
265
 
 
278
  topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
279
 
280
  # delete one anchor pred assign to mutliple gts
281
+ unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
 
 
 
282
 
283
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
284
  align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
285
  align_cls = F.one_hot(align_cls, self.class_num)
286
 
287
  # normalize class ditribution
288
+ iou_mat *= topk_mask
289
+ target_matrix *= topk_mask
290
  max_target = target_matrix.amax(dim=-1, keepdim=True)
291
  max_iou = iou_mat.amax(dim=-1, keepdim=True)
292
  normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
293
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
294
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
295
+ anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
296
+ return anchor_matched_targets, valid_mask.bool()
297
 
298
 
299
  class Vec2Box:
 
301
  self.device = device
302
 
303
  if hasattr(anchor_cfg, "strides"):
304
+ logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
305
  self.strides = anchor_cfg.strides
306
  else:
307
  logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
 
345
  self.device = device
346
 
347
  if hasattr(anchor_cfg, "strides"):
348
+ logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
349
  self.strides = anchor_cfg.strides
350
  else:
351
  logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
 
419
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
420
 
421
  batch_idx, *_ = torch.where(valid_mask)
422
+ nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
423
  predicts_nms = []
424
  for idx in range(cls_dist.size(0)):
425
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
 
432
  return predicts_nms
433
 
434
 
435
+ def calculate_map(predictions, ground_truths) -> Dict[str, Tensor]:
436
+ metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
437
+ mAP = metric([to_metrics_format(predictions)], [to_metrics_format(ground_truths)])
438
+ return mAP
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
 
 
440
 
441
+ def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
442
+ bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
443
+ if prediction.size(1) == 6:
444
+ bbox["scores"] = prediction[:, 5]
445
+ return bbox
yolo/utils/dataset_utils.py CHANGED
@@ -5,9 +5,10 @@ from pathlib import Path
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
  import numpy as np
8
- from loguru import logger
9
 
10
  from yolo.tools.data_conversion import discretize_categories
 
11
 
12
 
13
  def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path]:
@@ -111,3 +112,16 @@ def scale_segmentation(
111
  seg_array_with_cat.append(scaled_flat_seg_data)
112
 
113
  return seg_array_with_cat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
  import numpy as np
8
+ import torch
9
 
10
  from yolo.tools.data_conversion import discretize_categories
11
+ from yolo.utils.logger import logger
12
 
13
 
14
  def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path]:
 
112
  seg_array_with_cat.append(scaled_flat_seg_data)
113
 
114
  return seg_array_with_cat
115
+
116
+
117
+ def tensorlize(data):
118
+ img_paths, bboxes = zip(*data)
119
+ max_box = max(bbox.size(0) for bbox in bboxes)
120
+ padded_bbox_list = []
121
+ for bbox in bboxes:
122
+ padding = torch.full((max_box, 5), -1, dtype=torch.float32)
123
+ padding[: bbox.size(0)] = bbox
124
+ padded_bbox_list.append(padding)
125
+ bboxes = np.stack(padded_bbox_list)
126
+ img_paths = np.array(img_paths)
127
+ return img_paths, bboxes
yolo/utils/deploy_utils.py CHANGED
@@ -1,11 +1,11 @@
1
  from pathlib import Path
2
 
3
  import torch
4
- from loguru import logger
5
  from torch import Tensor
6
 
7
  from yolo.config.config import Config
8
  from yolo.model.yolo import create_model
 
9
 
10
 
11
  class FastModelLoader:
@@ -21,10 +21,10 @@ class FastModelLoader:
21
 
22
  def _validate_compiler(self):
23
  if self.compiler not in ["onnx", "trt", "deploy"]:
24
- logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
25
  self.compiler = None
26
  if self.cfg.device == "mps" and self.compiler == "trt":
27
- logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
28
  self.compiler = None
29
 
30
  def load_model(self, device):
@@ -59,7 +59,7 @@ class FastModelLoader:
59
  providers = ["CUDAExecutionProvider"]
60
  try:
61
  ort_session = InferenceSession(self.model_path, providers=providers)
62
- logger.info("πŸš€ Using ONNX as MODEL frameworks!")
63
  except Exception as e:
64
  logger.warning(f"🈳 Error loading ONNX model: {e}")
65
  ort_session = self._create_onnx_model(providers)
@@ -79,7 +79,7 @@ class FastModelLoader:
79
  output_names=["output"],
80
  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
81
  )
82
- logger.info(f"πŸ“₯ ONNX model saved to {self.model_path}")
83
  return InferenceSession(self.model_path, providers=providers)
84
 
85
  def _load_trt_model(self):
@@ -88,7 +88,7 @@ class FastModelLoader:
88
  try:
89
  model_trt = TRTModule()
90
  model_trt.load_state_dict(torch.load(self.model_path))
91
- logger.info("πŸš€ Using TensorRT as MODEL frameworks!")
92
  except FileNotFoundError:
93
  logger.warning(f"🈳 No found model weight at {self.model_path}")
94
  model_trt = self._create_trt_model()
@@ -102,5 +102,5 @@ class FastModelLoader:
102
  logger.info(f"♻️ Creating TensorRT model")
103
  model_trt = torch2trt(model.cuda(), [dummy_input])
104
  torch.save(model_trt.state_dict(), self.model_path)
105
- logger.info(f"πŸ“₯ TensorRT model saved to {self.model_path}")
106
  return model_trt
 
1
  from pathlib import Path
2
 
3
  import torch
 
4
  from torch import Tensor
5
 
6
  from yolo.config.config import Config
7
  from yolo.model.yolo import create_model
8
+ from yolo.utils.logger import logger
9
 
10
 
11
  class FastModelLoader:
 
21
 
22
  def _validate_compiler(self):
23
  if self.compiler not in ["onnx", "trt", "deploy"]:
24
+ logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
25
  self.compiler = None
26
  if self.cfg.device == "mps" and self.compiler == "trt":
27
+ logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
28
  self.compiler = None
29
 
30
  def load_model(self, device):
 
59
  providers = ["CUDAExecutionProvider"]
60
  try:
61
  ort_session = InferenceSession(self.model_path, providers=providers)
62
+ logger.info(":rocket: Using ONNX as MODEL frameworks!")
63
  except Exception as e:
64
  logger.warning(f"🈳 Error loading ONNX model: {e}")
65
  ort_session = self._create_onnx_model(providers)
 
79
  output_names=["output"],
80
  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
81
  )
82
+ logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
83
  return InferenceSession(self.model_path, providers=providers)
84
 
85
  def _load_trt_model(self):
 
88
  try:
89
  model_trt = TRTModule()
90
  model_trt.load_state_dict(torch.load(self.model_path))
91
+ logger.info(":rocket: Using TensorRT as MODEL frameworks!")
92
  except FileNotFoundError:
93
  logger.warning(f"🈳 No found model weight at {self.model_path}")
94
  model_trt = self._create_trt_model()
 
102
  logger.info(f"♻️ Creating TensorRT model")
103
  model_trt = torch2trt(model.cuda(), [dummy_input])
104
  torch.save(model_trt.state_dict(), self.model_path)
105
+ logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
106
  return model_trt
yolo/utils/logger.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only
4
+ from rich.console import Console
5
+ from rich.logging import RichHandler
6
+
7
+ logger = logging.getLogger("yolo")
8
+ logger.setLevel(logging.DEBUG)
9
+ logger.propagate = False
10
+ if rank_zero_only.rank == 0 and not logger.hasHandlers():
11
+ logger.addHandler(RichHandler(console=Console(), show_level=True, show_path=True, show_time=True, markup=True))
yolo/utils/logging_utils.py CHANGED
@@ -11,55 +11,39 @@ Example:
11
  custom_logger()
12
  """
13
 
14
- import os
15
- import random
16
- import sys
17
  from collections import deque
 
18
  from pathlib import Path
19
  from typing import Any, Dict, List, Optional, Tuple, Union
20
 
21
  import numpy as np
22
  import torch
23
  import wandb
24
- import wandb.errors.term
25
- from loguru import logger
 
 
 
26
  from omegaconf import ListConfig
 
27
  from rich.console import Console, Group
28
- from rich.progress import (
29
- BarColumn,
30
- Progress,
31
- SpinnerColumn,
32
- TextColumn,
33
- TimeRemainingColumn,
34
- )
35
  from rich.table import Table
 
36
  from torch import Tensor
37
  from torch.nn import ModuleList
38
- from torch.optim import Optimizer
39
- from torchvision.transforms.functional import pil_to_tensor
40
 
41
  from yolo.config.config import Config, YOLOLayer
42
  from yolo.model.yolo import YOLO
43
- from yolo.tools.drawer import draw_bboxes
44
  from yolo.utils.solver_utils import make_ap_table
45
 
46
 
47
- def custom_logger(quite: bool = False):
48
- logger.remove()
49
- if quite:
50
- return
51
- logger.add(
52
- sys.stderr,
53
- colorize=True,
54
- format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
55
- )
56
-
57
-
58
  # TODO: should be moved to correct position
59
  def set_seed(seed):
60
- random.seed(seed)
61
- np.random.seed(seed)
62
- torch.manual_seed(seed)
63
  if torch.cuda.is_available():
64
  torch.cuda.manual_seed(seed)
65
  torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
@@ -67,189 +51,223 @@ def set_seed(seed):
67
  torch.backends.cudnn.benchmark = False
68
 
69
 
70
- class ProgressLogger(Progress):
71
- def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
72
- set_seed(cfg.lucky_number)
73
- self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
74
- self.quite_mode = self.local_rank or getattr(cfg, "quite", False)
75
- custom_logger(self.quite_mode)
76
- self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
77
-
78
- progress_bar = (
79
- SpinnerColumn(),
80
- TextColumn("[progress.description]{task.description}"),
81
- BarColumn(bar_width=None),
82
- TextColumn("{task.completed:.0f}/{task.total:.0f}"),
83
- TimeRemainingColumn(),
84
- )
85
- self.ap_table = Table()
86
- # TODO: load maxlen by config files
87
- self.ap_past_list = deque(maxlen=5)
88
- self.last_result = 0
89
- super().__init__(*args, *progress_bar, **kwargs)
90
-
91
- self.use_wandb = cfg.use_wandb
92
- if self.use_wandb and self.local_rank == 0:
93
- wandb.errors.term._log = custom_wandb_log
94
- self.wandb = wandb.init(
95
- project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
96
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- self.use_tensorboard = cfg.use_tensorboard
99
- if self.use_tensorboard and self.local_rank == 0:
100
- from torch.utils.tensorboard import SummaryWriter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- self.tb_writer = SummaryWriter(log_dir=self.save_path / "tensorboard")
103
- logger.opt(colors=True).info(f"πŸ“ Enable TensorBoard locally at <blue><u>http://localhost:6006</></>")
104
 
105
- def rank_check(logging_function):
106
- def wrapper(self, *args, **kwargs):
107
- if getattr(self, "local_rank", 0) != 0:
108
- return
109
- return logging_function(self, *args, **kwargs)
 
 
110
 
111
- return wrapper
112
 
113
- def get_renderable(self):
114
- renderable = Group(*self.get_renderables(), self.ap_table)
115
- return renderable
116
 
117
- @rank_check
118
- def start_train(self, num_epochs: int):
119
- self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
120
- self.update(self.task_epoch, advance=-0.5)
121
-
122
- @rank_check
123
- def start_one_epoch(
124
- self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
125
- ):
126
- self.num_batches = num_batches
127
- self.task = task
128
- if hasattr(self, "task_epoch"):
129
- self.update(self.task_epoch, description=f"[cyan] Preparing Data")
130
-
131
- if optimizer is not None:
132
- lr_values = [params["lr"] for params in optimizer.param_groups]
133
- lr_names = ["Learning Rate/bias", "Learning Rate/norm", "Learning Rate/conv"]
134
- if self.use_wandb:
135
- for lr_name, lr_value in zip(lr_names, lr_values):
136
- self.wandb.log({lr_name: lr_value}, step=epoch_idx)
137
-
138
- if self.use_tensorboard:
139
- for lr_name, lr_value in zip(lr_names, lr_values):
140
- self.tb_writer.add_scalar(lr_name, lr_value, global_step=epoch_idx)
141
-
142
- self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
143
-
144
- @rank_check
145
- def one_batch(self, batch_info: Dict[str, Tensor] = None):
146
- epoch_descript = "[cyan]" + self.task + "[white] |"
147
- batch_descript = "|"
148
- if self.task == "Train":
149
- self.update(self.task_epoch, advance=1 / self.num_batches)
150
- for info_name, info_val in batch_info.items():
151
- epoch_descript += f"{info_name: ^9}|"
152
- batch_descript += f" {info_val:2.2f} |"
153
- self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
154
- if hasattr(self, "task_epoch"):
155
- self.update(self.task_epoch, description=epoch_descript)
156
-
157
- @rank_check
158
- def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
159
- if self.task == "Train":
160
- prefix = "Loss"
161
- elif self.task == "Validate":
162
- prefix = "Metrics"
163
- batch_info = {f"{prefix}/{key}": value for key, value in batch_info.items()}
164
- if self.use_wandb:
165
- self.wandb.log(batch_info, step=epoch_idx)
166
- if self.use_tensorboard:
167
- for key, value in batch_info.items():
168
- self.tb_writer.add_scalar(key, value, epoch_idx)
169
-
170
- self.remove_task(self.batch_task)
171
-
172
- @rank_check
173
- def visualize_image(
174
- self,
175
- images: Optional[Tensor] = None,
176
- ground_truth: Optional[Tensor] = None,
177
- prediction: Optional[Union[List[Tensor], Tensor]] = None,
178
- epoch_idx: int = 0,
179
- ) -> None:
180
- """
181
- Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
182
-
183
- Args:
184
- images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
185
- ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
186
- prediction (prediction: Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6) or (N, 6). Defaults to None.
187
- epoch_idx (int): Current epoch index. Defaults to 0.
188
- """
189
- if images is not None:
190
- images = images[0] if images.ndim == 4 else images
191
- if self.use_wandb:
192
- wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
193
- if self.use_tensorboard:
194
- self.tb_writer.add_image("Media/Input Image", images, 1)
195
-
196
- if ground_truth is not None:
197
- gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
198
- if self.use_wandb:
199
- wandb.log(
200
- {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
201
- step=epoch_idx,
202
- )
203
- if self.use_tensorboard:
204
- self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
205
-
206
- if prediction is not None:
207
- pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
208
- if self.use_wandb:
209
- wandb.log(
210
- {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
211
- step=epoch_idx,
212
- )
213
- if self.use_tensorboard:
214
- self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
215
-
216
- @rank_check
217
- def start_pycocotools(self):
218
- self.batch_task = self.add_task("[green]Run pycocotools", total=1)
219
-
220
- @rank_check
221
- def finish_pycocotools(self, result, epoch_idx=-1):
222
- ap_table, ap_main = make_ap_table(result * 100, self.ap_past_list, self.last_result, epoch_idx)
223
- self.last_result = np.maximum(result, self.last_result)
224
- self.ap_past_list.append((epoch_idx, ap_main))
225
- self.ap_table = ap_table
226
-
227
- if self.use_wandb:
228
- self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
229
- if self.use_tensorboard:
230
- # TODO: waiting torch bugs fix, https://github.com/pytorch/pytorch/issues/32651
231
- self.tb_writer.add_scalar("PyCOCO/AP @ .5:.95", ap_main[2], epoch_idx)
232
- self.tb_writer.add_scalar("PyCOCO/AP @ .5", ap_main[5], epoch_idx)
233
-
234
- self.update(self.batch_task, advance=1)
235
- self.refresh()
236
- self.remove_task(self.batch_task)
237
 
238
- @rank_check
239
- def finish_train(self):
240
- self.remove_task(self.task_epoch)
241
- self.stop()
242
- if self.use_wandb:
243
- self.wandb.finish()
244
- if self.use_tensorboard:
245
- self.tb_writer.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
 
 
 
 
 
 
 
247
 
248
- def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
249
- if silent:
250
- return
251
- for line in string.split("\n"):
252
- logger.opt(raw=not newline, colors=True).info("🌐 " + line)
253
 
254
 
255
  def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
@@ -279,6 +297,7 @@ def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
279
  console.print(table)
280
 
281
 
 
282
  def validate_log_directory(cfg: Config, exp_name: str) -> Path:
283
  base_path = Path(cfg.out_path, cfg.task.task)
284
  save_path = base_path / exp_name
@@ -296,8 +315,9 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
296
  )
297
 
298
  save_path.mkdir(parents=True, exist_ok=True)
299
- logger.opt(colors=True).info(f"πŸ“„ Created log folder: <u><fg #808080>{save_path}</></>")
300
- logger.add(save_path / "output.log", mode="w", backtrace=True, diagnose=True)
 
301
  return save_path
302
 
303
 
@@ -332,4 +352,4 @@ def log_bbox(
332
  bbox_entry["scores"] = {"confidence": conf[0]}
333
  bbox_list.append(bbox_entry)
334
 
335
- return bbox_list
 
11
  custom_logger()
12
  """
13
 
14
+ import logging
 
 
15
  from collections import deque
16
+ from logging import FileHandler
17
  from pathlib import Path
18
  from typing import Any, Dict, List, Optional, Tuple, Union
19
 
20
  import numpy as np
21
  import torch
22
  import wandb
23
+ from lightning import LightningModule, Trainer, seed_everything
24
+ from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
25
+ from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
26
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
27
+ from lightning.pytorch.utilities import rank_zero_only
28
  from omegaconf import ListConfig
29
+ from rich import get_console, reconfigure
30
  from rich.console import Console, Group
31
+ from rich.logging import RichHandler
 
 
 
 
 
 
32
  from rich.table import Table
33
+ from rich.text import Text
34
  from torch import Tensor
35
  from torch.nn import ModuleList
36
+ from typing_extensions import override
 
37
 
38
  from yolo.config.config import Config, YOLOLayer
39
  from yolo.model.yolo import YOLO
40
+ from yolo.utils.logger import logger
41
  from yolo.utils.solver_utils import make_ap_table
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
44
  # TODO: should be moved to correct position
45
  def set_seed(seed):
46
+ seed_everything(seed)
 
 
47
  if torch.cuda.is_available():
48
  torch.cuda.manual_seed(seed)
49
  torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
 
51
  torch.backends.cudnn.benchmark = False
52
 
53
 
54
+ class YOLOCustomProgress(CustomProgress):
55
+ def get_renderable(self):
56
+ renderable = Group(*self.get_renderables())
57
+ if hasattr(self, "table"):
58
+ renderable = Group(*self.get_renderables(), self.table)
59
+ return renderable
60
+
61
+
62
+ class YOLORichProgressBar(RichProgressBar):
63
+ @override
64
+ @rank_zero_only
65
+ def _init_progress(self, trainer: "Trainer") -> None:
66
+ if self.is_enabled and (self.progress is None or self._progress_stopped):
67
+ self._reset_progress_bar_ids()
68
+ reconfigure(**self._console_kwargs)
69
+ self._console = Console()
70
+ self._console.clear_live()
71
+ self.progress = YOLOCustomProgress(
72
+ *self.configure_columns(trainer),
73
+ auto_refresh=False,
74
+ disable=self.is_disabled,
75
+ console=self._console,
 
 
 
 
76
  )
77
+ self.progress.start()
78
+
79
+ self._progress_stopped = False
80
+
81
+ self.max_result = 0
82
+ self.past_results = deque(maxlen=5)
83
+ self.progress.table = Table()
84
+
85
+ @override
86
+ def _get_train_description(self, current_epoch: int) -> str:
87
+ return Text("[cyan]Train [white]|")
88
+
89
+ @override
90
+ @rank_zero_only
91
+ def on_train_start(self, trainer, pl_module):
92
+ self._init_progress(trainer)
93
+ num_epochs = trainer.max_epochs - 1
94
+ self.task_epoch = self._add_task(
95
+ total_batches=num_epochs,
96
+ description=f"[cyan]Start Training {num_epochs} epochs",
97
+ )
98
+ self.max_result = 0
99
+ self.past_results.clear()
100
+ self.progress.update(self.task_epoch, advance=-0.5)
101
+
102
+ @override
103
+ @rank_zero_only
104
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
105
+ self._update(self.train_progress_bar_id, batch_idx + 1)
106
+ self._update_metrics(trainer, pl_module)
107
+ epoch_descript = "[cyan]Train [white]|"
108
+ batch_descript = "[green]Train [white]|"
109
+ metrics = self.get_metrics(trainer, pl_module)
110
+ metrics.pop("v_num")
111
+ for metrics_name, metrics_val in metrics.items():
112
+ if "Loss_step" in metrics_name:
113
+ epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
114
+ batch_descript += f" {metrics_val:2.2f} |"
115
+
116
+ self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
117
+ self.progress.update(self.train_progress_bar_id, description=batch_descript)
118
+ self.refresh()
119
 
120
+ @override
121
+ @rank_zero_only
122
+ def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
123
+ self._update_metrics(trainer, pl_module)
124
+ self.progress.remove_task(self.train_progress_bar_id)
125
+ self.train_progress_bar_id = None
126
+
127
+ @override
128
+ @rank_zero_only
129
+ def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
130
+ if trainer.state.fn == "fit":
131
+ self._update_metrics(trainer, pl_module)
132
+ self.reset_dataloader_idx_tracker()
133
+ all_metrics = self.get_metrics(trainer, pl_module)
134
+
135
+ ap_ar_list = [
136
+ key
137
+ for key in all_metrics.keys()
138
+ if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
139
+ ]
140
+ score = np.array([all_metrics[key] for key in ap_ar_list]) * 100
141
+
142
+ self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
143
+ self.max_result = np.maximum(score, self.max_result)
144
+ self.past_results.append((trainer.current_epoch, ap_main))
145
+
146
+ @override
147
+ def refresh(self) -> None:
148
+ if self.progress:
149
+ self.progress.refresh()
150
+
151
+ @property
152
+ def validation_description(self) -> str:
153
+ return "[green]Validation"
154
+
155
+
156
+ class YOLORichModelSummary(RichModelSummary):
157
+ @staticmethod
158
+ @override
159
+ def summarize(
160
+ summary_data: List[Tuple[str, List[str]]],
161
+ total_parameters: int,
162
+ trainable_parameters: int,
163
+ model_size: float,
164
+ total_training_modes: Dict[str, int],
165
+ **summarize_kwargs: Any,
166
+ ) -> None:
167
+ from lightning.pytorch.utilities.model_summary import get_human_readable_count
168
 
169
+ console = get_console()
 
170
 
171
+ header_style: str = summarize_kwargs.get("header_style", "bold magenta")
172
+ table = Table(header_style=header_style)
173
+ table.add_column(" ", style="dim")
174
+ table.add_column("Name", justify="left", no_wrap=True)
175
+ table.add_column("Type")
176
+ table.add_column("Params", justify="right")
177
+ table.add_column("Mode")
178
 
179
+ column_names = list(zip(*summary_data))[0]
180
 
181
+ for column_name in ["In sizes", "Out sizes"]:
182
+ if column_name in column_names:
183
+ table.add_column(column_name, justify="right", style="white")
184
 
185
+ rows = list(zip(*(arr[1] for arr in summary_data)))
186
+ for row in rows:
187
+ table.add_row(*row)
188
+
189
+ console.print(table)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ parameters = []
192
+ for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
193
+ parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
194
+
195
+ grid = Table(header_style=header_style)
196
+ table.add_column(" ", style="dim")
197
+ grid.add_column("[bold]Attributes[/]")
198
+ grid.add_column("Value")
199
+
200
+ grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
201
+ grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
202
+ grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
203
+ grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
204
+ grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
205
+ grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")
206
+
207
+ console.print(grid)
208
+
209
+
210
+ class ImageLogger(Callback):
211
+ def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
212
+ if batch_idx != 0:
213
+ return
214
+ batch_size, images, targets, rev_tensor, img_paths = batch
215
+ gt_boxes = targets[0] if targets.ndim == 3 else targets
216
+ pred_boxes = outputs[0] if isinstance(outputs, list) else outputs
217
+ images = [images[0]]
218
+ step = trainer.current_epoch
219
+ for logger in trainer.loggers:
220
+ if isinstance(logger, WandbLogger):
221
+ logger.log_image("Input Image", images, step=step)
222
+ logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)])
223
+ logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)])
224
+
225
+
226
+ def setup_logger(logger_name, quite=False):
227
+ class EmojiFormatter(logging.Formatter):
228
+ def format(self, record, emoji=":high_voltage:"):
229
+ return f"{emoji} {super().format(record)}"
230
+
231
+ rich_handler = RichHandler(markup=True)
232
+ rich_handler.setFormatter(EmojiFormatter("%(message)s"))
233
+ rich_logger = logging.getLogger(logger_name)
234
+ if rich_logger:
235
+ rich_logger.handlers.clear()
236
+ rich_logger.addHandler(rich_handler)
237
+ if quite:
238
+ rich_logger.setLevel(logging.ERROR)
239
+
240
+
241
+ def setup(cfg: Config):
242
+ quite = hasattr(cfg, "quite")
243
+ setup_logger("lightning.fabric", quite=quite)
244
+ setup_logger("lightning.pytorch", quite=quite)
245
+
246
+ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
247
+ if silent:
248
+ return
249
+ for line in string.split("\n"):
250
+ logger.info(Text.from_ansi(":globe_with_meridians: " + line))
251
+
252
+ wandb.errors.term._log = custom_wandb_log
253
+
254
+ save_path = validate_log_directory(cfg, cfg.name)
255
+
256
+ progress, loggers = [], []
257
+
258
+ if quite:
259
+ logger.setLevel(logging.ERROR)
260
+ return progress, loggers, save_path
261
 
262
+ progress.append(YOLORichProgressBar())
263
+ progress.append(YOLORichModelSummary())
264
+ progress.append(ImageLogger())
265
+ if cfg.use_tensorboard:
266
+ loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path))
267
+ if cfg.use_wandb:
268
+ loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))
269
 
270
+ return progress, loggers, save_path
 
 
 
 
271
 
272
 
273
  def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
 
297
  console.print(table)
298
 
299
 
300
+ @rank_zero_only
301
  def validate_log_directory(cfg: Config, exp_name: str) -> Path:
302
  base_path = Path(cfg.out_path, cfg.task.task)
303
  save_path = base_path / exp_name
 
315
  )
316
 
317
  save_path.mkdir(parents=True, exist_ok=True)
318
+ if not getattr(cfg, "quite", False):
319
+ logger.info(f"πŸ“„ Created log folder: [blue b u]{save_path}[/]")
320
+ logger.addHandler(FileHandler(save_path / "output.log"))
321
  return save_path
322
 
323
 
 
352
  bbox_entry["scores"] = {"confidence": conf[0]}
353
  bbox_list.append(bbox_entry)
354
 
355
+ return {"predictions": {"box_data": bbox_list}}
yolo/utils/model_utils.py CHANGED
@@ -4,7 +4,6 @@ from typing import List, Optional, Type, Union
4
 
5
  import torch
6
  import torch.distributed as dist
7
- from loguru import logger
8
  from omegaconf import ListConfig
9
  from torch import Tensor
10
  from torch.optim import Optimizer
@@ -13,6 +12,7 @@ from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
13
  from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
14
  from yolo.model.yolo import YOLO
15
  from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
 
16
 
17
 
18
  class ExponentialMovingAverage:
@@ -52,9 +52,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
52
  conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
53
 
54
  model_parameters = [
55
- {"params": bias_params, "weight_decay": 0},
56
- {"params": conv_params},
57
- {"params": norm_params, "weight_decay": 0},
58
  ]
59
 
60
  def next_epoch(self, batch_num):
@@ -65,12 +65,16 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
65
 
66
  def next_batch(self):
67
  self.batch_idx += 1
 
68
  for lr_idx, param_group in enumerate(self.param_groups):
69
  min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
70
  param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
 
 
71
 
72
  optimizer_class.next_batch = next_batch
73
  optimizer_class.next_epoch = next_epoch
 
74
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
75
  optimizer.max_lr = [0.1, 0, 0]
76
  return optimizer
@@ -120,7 +124,7 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
120
  return device, ddp_flag
121
 
122
 
123
- class PostProccess:
124
  """
125
  TODO: function document
126
  scale back the prediction and do nms for pred_bbox
@@ -168,6 +172,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
168
  batch_json = []
169
  for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
170
  scale, shift = box_reverse.split([1, 4])
 
171
  bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
172
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
173
  for cls, *pos, conf in bboxes:
 
4
 
5
  import torch
6
  import torch.distributed as dist
 
7
  from omegaconf import ListConfig
8
  from torch import Tensor
9
  from torch.optim import Optimizer
 
12
  from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
13
  from yolo.model.yolo import YOLO
14
  from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
15
+ from yolo.utils.logger import logger
16
 
17
 
18
  class ExponentialMovingAverage:
 
52
  conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
53
 
54
  model_parameters = [
55
+ {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
56
+ {"params": conv_params, "momentum": 0.8},
57
+ {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
58
  ]
59
 
60
  def next_epoch(self, batch_num):
 
65
 
66
  def next_batch(self):
67
  self.batch_idx += 1
68
+ lr_dict = dict()
69
  for lr_idx, param_group in enumerate(self.param_groups):
70
  min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
71
  param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
72
+ lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
73
+ return lr_dict
74
 
75
  optimizer_class.next_batch = next_batch
76
  optimizer_class.next_epoch = next_epoch
77
+
78
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
79
  optimizer.max_lr = [0.1, 0, 0]
80
  return optimizer
 
124
  return device, ddp_flag
125
 
126
 
127
+ class PostProcess:
128
  """
129
  TODO: function document
130
  scale back the prediction and do nms for pred_bbox
 
172
  batch_json = []
173
  for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
174
  scale, shift = box_reverse.split([1, 4])
175
+ bboxes = bboxes.clone()
176
  bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
177
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
178
  for cls, *pos, conf in bboxes:
yolo/utils/solver_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import contextlib
2
  import io
 
3
 
4
  import numpy as np
5
  from pycocotools.coco import COCO
@@ -17,7 +18,7 @@ def calculate_ap(coco_gt: COCO, pd_path):
17
  return coco_eval.stats
18
 
19
 
20
- def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
21
  ap_table = Table()
22
  ap_table.add_column("Epoch", justify="center", style="white", width=5)
23
  ap_table.add_column("Avg. Precision", justify="left", style="cyan")
@@ -30,7 +31,7 @@ def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
30
  if past_result:
31
  ap_table.add_row()
32
 
33
- color = np.where(last_score <= score, "[green]", "[red]")
34
 
35
  this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
36
  metrics = [
 
1
  import contextlib
2
  import io
3
+ from typing import Dict
4
 
5
  import numpy as np
6
  from pycocotools.coco import COCO
 
18
  return coco_eval.stats
19
 
20
 
21
+ def make_ap_table(score: Dict[str, float], past_result=[], max_result=None, epoch=-1):
22
  ap_table = Table()
23
  ap_table.add_column("Epoch", justify="center", style="white", width=5)
24
  ap_table.add_column("Avg. Precision", justify="left", style="cyan")
 
31
  if past_result:
32
  ap_table.add_row()
33
 
34
+ color = np.where(max_result <= score, "[green]", "[red]")
35
 
36
  this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
37
  metrics = [