henry000 committed on
Commit 8ce9eff · 2 Parent(s): e53ff09 7f8235a

🔀 [Merge] branch 'Lightning'

tests/test_utils/test_bounding_box_utils.py CHANGED
@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):
 
 
 def test_bbox_nms():
-    cls_dist = tensor(
-        [[[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [[0.4, 0.4, 0.2], [0.5, 0.4, 0.1]]]  # Example class distribution
+    cls_dist = torch.tensor(
+        [
+            [
+                [0.7, 0.1, 0.2],  # High confidence, class 0
+                [0.3, 0.6, 0.1],  # High confidence, class 1
+                [-3.0, -2.0, -1.0],  # low confidence, class 2
+                [0.6, 0.2, 0.2],  # Medium confidence, class 0
+            ],
+            [
+                [0.55, 0.25, 0.2],  # Medium confidence, class 0
+                [-4.0, -0.5, -2.0],  # low confidence, class 1
+                [0.15, 0.2, 0.65],  # Medium confidence, class 2
+                [0.8, 0.1, 0.1],  # High confidence, class 0
+            ],
+        ],
+        dtype=float32,
     )
-    bbox = tensor(
-        [[[50, 50, 100, 100], [60, 60, 110, 110]], [[40, 40, 90, 90], [70, 70, 120, 120]]],  # Example bounding boxes
+
+    bbox = torch.tensor(
+        [
+            [
+                [0, 0, 160, 120],  # Overlaps with box 4
+                [160, 120, 320, 240],
+                [0, 120, 160, 240],
+                [16, 12, 176, 132],
+            ],
+            [
+                [0, 0, 160, 120],  # Overlaps with box 4
+                [160, 120, 320, 240],
+                [0, 120, 160, 240],
+                [16, 12, 176, 132],
+            ],
+        ],
         dtype=float32,
     )
+
     nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
 
-    expected_output = [
-        tensor(
-            [
-                [1.0000, 50.0000, 50.0000, 100.0000, 100.0000, 0.6682],
-                [0.0000, 60.0000, 60.0000, 110.0000, 110.0000, 0.6457],
-            ]
-        )
-    ]
+    # Batch 1:
+    # - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
+    # - box 2 is kept with class 1
+    # - box 3 is rejected by the confidence filter
+    # Batch 2:
+    # - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
+    # - box 2 is rejected by the confidence filter
+    # - box 3 is kept with class 2
+    expected_output = torch.tensor(
+        [
+            [
+                [0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
+                [1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
+            ],
+            [
+                [0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
+                [2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
+            ],
+        ]
+    )
 
     output = bbox_nms(cls_dist, bbox, nms_cfg)
 
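Note: the expected confidences line up with the sigmoid of the winning class logits, so the fixture can be sanity-checked on its own (a standalone sketch):

import torch

# winning logits from cls_dist above: 0.7 and 0.6 (batch 1), 0.8 and 0.65 (batch 2)
print(torch.sigmoid(torch.tensor([0.7, 0.6, 0.8, 0.65])))
# tensor([0.6682, 0.6457, 0.6900, 0.6570])  -- the last column of expected_output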
yolo/__init__.py CHANGED
@@ -2,18 +2,22 @@ from yolo.config.config import Config, NMSConfig
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
+from yolo.tools.solver import TrainModel
 from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
 from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import ProgressLogger, custom_logger
+from yolo.utils.logging_utils import (
+    ImageLogger,
+    YOLORichModelSummary,
+    YOLORichProgressBar,
+)
 from yolo.utils.model_utils import PostProccess
 
 all = [
     "create_model",
     "Config",
-    "ProgressLogger",
+    "YOLORichProgressBar",
     "NMSConfig",
-    "custom_logger",
+    "YOLORichModelSummary",
     "validate_log_directory",
     "draw_bboxes",
     "Vec2Box",
@@ -21,10 +25,9 @@ all = [
     "bbox_nms",
     "create_converter",
     "AugmentationComposer",
+    "ImageLogger",
     "create_dataloader",
     "FastModelLoader",
-    "ModelTester",
-    "ModelTrainer",
-    "ModelValidator",
+    "TrainModel",
     "PostProccess",
 ]
yolo/config/general.yaml CHANGED
@@ -7,7 +7,7 @@ out_path: runs
 exist_ok: True
 
 lucky_number: 10
-use_wandb: False
+use_wandb: True
 use_tensorboard: False
 
 weight: True # Path to weight or True for auto, False for no pretrained weight
yolo/config/task/inference.yaml CHANGED
@@ -8,4 +8,4 @@ data:
 nms:
   min_confidence: 0.5
   min_iou: 0.5
-# save_predict: True
+save_predict: True
yolo/config/task/validation.yaml CHANGED
@@ -8,5 +8,5 @@ data:
   pin_memory: True
   data_augment: {}
 nms:
-  min_confidence: 0.05
-  min_iou: 0.9
+  min_confidence: 0.0001
+  min_iou: 0.7
yolo/lazy.py CHANGED
@@ -2,41 +2,41 @@ import sys
 from pathlib import Path
 
 import hydra
+from lightning import Trainer
 
 project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))
 
 from yolo.config.config import Config
-from yolo.model.yolo import create_model
-from yolo.tools.data_loader import create_dataloader
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import create_converter
-from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import ProgressLogger
-from yolo.utils.model_utils import get_device
+from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
+from yolo.utils.logging_utils import setup
 
 
 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-    progress = ProgressLogger(cfg, exp_name=cfg.name)
-    device, use_ddp = get_device(cfg.device)
-    dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
-    if getattr(cfg.task, "fast_inference", False):
-        model = FastModelLoader(cfg).load_model(device)
-    else:
-        model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
-        model = model.to(device)
-
-    converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
-
-    if cfg.task.task == "train":
-        solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
-    if cfg.task.task == "validation":
-        solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
-    if cfg.task.task == "inference":
-        solver = ModelTester(cfg, model, converter, progress, device)
-    progress.start()
-    solver.solve(dataloader)
+    callbacks, loggers = setup(cfg)
+
+    trainer = Trainer(
+        accelerator="cuda",
+        max_epochs=getattr(cfg.task, "epoch", None),
+        precision="16-mixed",
+        callbacks=callbacks,
+        logger=loggers,
+        log_every_n_steps=1,
+        gradient_clip_val=10,
+        deterministic=True,
+    )
+
+    match cfg.task.task:
+        case "train":
+            model = TrainModel(cfg)
+            trainer.fit(model)
+        case "validation":
+            model = ValidateModel(cfg)
+            trainer.validate(model)
+        case "inference":
+            model = InferenceModel(cfg)
+            trainer.predict(model)
 
 
 if __name__ == "__main__":
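Note: the `match` statement makes the new entry point require Python 3.10 or newer. With Hydra's override grammar the task is selected from the command line; a hypothetical invocation, assuming the `task` config group under `yolo/config/task/` is wired into the defaults list:

python yolo/lazy.py task=validation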
yolo/model/module.py CHANGED
@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from loguru import logger
 from torch import Tensor, nn
 from torch.nn.common_types import _size_2_t
 
+from yolo.utils.logger import logger
 from yolo.utils.module_utils import auto_pad, create_activation_function, round_up
 
 
yolo/model/yolo.py CHANGED
@@ -3,12 +3,12 @@ from pathlib import Path
 from typing import Dict, List, Union
 
 import torch
-from loguru import logger
 from omegaconf import ListConfig, OmegaConf
 from torch import nn
 
 from yolo.config.config import ModelConfig, YOLOLayer
 from yolo.tools.dataset_preparation import prepare_weight
+from yolo.utils.logger import logger
 from yolo.utils.module_utils import get_layer_map
 
 
@@ -32,10 +32,10 @@ class YOLO(nn.Module):
     def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
         self.layer_index = {}
         output_dim, layer_idx = [3], 1
-        logger.info(f"🚜 Building YOLO")
+        logger.info(f":tractor: Building YOLO")
         for arch_name in model_arch:
             if model_arch[arch_name]:
-                logger.info(f"  🏗️ Building {arch_name}")
+                logger.info(f"  :building_construction: Building {arch_name}")
                 for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
                     layer_type, layer_info = next(iter(layer_spec.items()))
                     layer_args = layer_info.get("args", {})
@@ -123,7 +123,7 @@ class YOLO(nn.Module):
             weights: A OrderedDict containing the new weights.
         """
         if isinstance(weights, Path):
-            weights = torch.load(weights, map_location=torch.device("cpu"))
+            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
         if "model_state_dict" in weights:
             weights = weights["model_state_dict"]
 
@@ -144,7 +144,7 @@ class YOLO(nn.Module):
 
         for error_name, error_set in error_dict.items():
             for weight_name in error_set:
-                logger.warning(f"⚠️ Weight {error_name} for key: {'.'.join(weight_name)}")
+                logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")
 
         self.model.load_state_dict(model_state_dict)
 
@@ -171,7 +171,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
         prepare_weight(weight_path=weight_path)
     if weight_path.exists():
         model.save_load_weights(weight_path)
-        logger.info("✅ Success load model & weight")
+        logger.info(":white_check_mark: Success load model & weight")
     else:
-        logger.info("✅ Success load model")
+        logger.info(":white_check_mark: Success load model")
     return model
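Note on `weights_only=False`: starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` and refuses to unpickle arbitrary Python objects, so checkpoints that store more than plain tensors need the explicit opt-out. A sketch (the file name follows the trainer's old `E{epoch:03d}.pt` pattern and is hypothetical):

import torch

# only do this for checkpoints from a trusted source
checkpoint = torch.load("E000.pt", map_location="cpu", weights_only=False)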
yolo/tools/data_loader.py CHANGED
@@ -5,12 +5,10 @@ from typing import Generator, List, Tuple, Union
 
 import numpy as np
 import torch
-from loguru import logger
 from PIL import Image
 from rich.progress import track
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
 
 from yolo.config.config import DataConfig, DatasetConfig
 from yolo.tools.data_augmentation import *
@@ -20,7 +18,9 @@ from yolo.utils.dataset_utils import (
     create_image_metadata,
     locate_label_paths,
     scale_segmentation,
+    tensorlize,
 )
+from yolo.utils.logger import logger
 
 
 class YoloDataset(Dataset):
@@ -32,7 +32,8 @@ class YoloDataset(Dataset):
         transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
         self.transform = AugmentationComposer(transforms, self.image_size)
         self.transform.get_more_data = self.get_more_data
-        self.data = self.load_data(Path(dataset_cfg.path), phase_name)
+        img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
+        self.img_paths, self.bboxes = img_paths, bboxes
 
     def load_data(self, dataset_path: Path, phase_name: str):
         """
@@ -45,15 +46,15 @@ class YoloDataset(Dataset):
         Returns:
             dict: The loaded data from the cache for the specified phase.
         """
-        cache_path = dataset_path / f"{phase_name}.cache"
+        cache_path = dataset_path / f"{phase_name}.cache1"
 
         if not cache_path.exists():
-            logger.info("🏭 Generating {} cache", phase_name)
+            logger.info(f":factory: Generating {phase_name} cache")
            data = self.filter_data(dataset_path, phase_name)
            torch.save(data, cache_path)
         else:
            data = torch.load(cache_path, weights_only=False)
-            logger.info("📦 Loaded {} cache", phase_name)
+            logger.info(f":package: Loaded {phase_name} cache")
         return data
 
     def filter_data(self, dataset_path: Path, phase_name: str) -> list:
@@ -103,7 +104,7 @@ class YoloDataset(Dataset):
             img_path = images_path / image_name
             data.append((img_path, labels))
             valid_inputs += 1
-        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
+        logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
         return data
 
     def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
@@ -132,9 +133,11 @@ class YoloDataset(Dataset):
         return torch.zeros((0, 5))
 
     def get_data(self, idx):
-        img_path, bboxes = self.data[idx]
-        img = Image.open(img_path).convert("RGB")
-        return img, bboxes, img_path
+        img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
+        valid_mask = bboxes[:, 0] != -1
+        with Image.open(img_path) as img:
+            img = img.convert("RGB")
+        return img, torch.from_numpy(bboxes[valid_mask]), img_path
 
     def get_more_data(self, num: int = 1):
         indices = torch.randint(0, len(self), (num,))
@@ -143,67 +146,59 @@ class YoloDataset(Dataset):
     def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
         img, bboxes, img_path = self.get_data(idx)
         img, bboxes, rev_tensor = self.transform(img, bboxes)
+        bboxes[:, [1, 3]] *= self.image_size[0]
+        bboxes[:, [2, 4]] *= self.image_size[1]
         return img, bboxes, rev_tensor, img_path
 
     def __len__(self) -> int:
-        return len(self.data)
-
-
-class YoloDataLoader(DataLoader):
-    def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
-        """Initializes the YoloDataLoader with hydra-config files."""
-        dataset = YoloDataset(data_cfg, dataset_cfg, task)
-        sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
-        self.image_size = data_cfg.image_size[0]
-        super().__init__(
-            dataset,
-            batch_size=data_cfg.batch_size,
-            sampler=sampler,
-            shuffle=data_cfg.shuffle and not use_ddp,
-            num_workers=data_cfg.cpu_num,
-            pin_memory=data_cfg.pin_memory,
-            collate_fn=self.collate_fn,
-        )
-
-    def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
-        """
-        A collate function to handle batching of images and their corresponding targets.
-
-        Args:
-            batch (list of tuples): Each tuple contains:
-                - image (Tensor): The image tensor.
-                - labels (Tensor): The tensor of labels for the image.
-
-        Returns:
-            Tuple[Tensor, List[Tensor]]: A tuple containing:
-                - A tensor of batched images.
-                - A list of tensors, each corresponding to bboxes for each image in the batch.
-        """
-        batch_size = len(batch)
-        target_sizes = [item[1].size(0) for item in batch]
-        # TODO: Improve readability of these proccess
-        # TODO: remove maxBbox or reduce loss function memory usage
-        batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
-        batch_targets[:, :, 0] = -1
-        for idx, target_size in enumerate(target_sizes):
-            batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
-        batch_targets[:, :, 1:] *= self.image_size
+        return len(self.bboxes)
 
-        batch_images, _, batch_reverse, batch_path = zip(*batch)
-        batch_images = torch.stack(batch_images)
-        batch_reverse = torch.stack(batch_reverse)
 
-        return batch_size, batch_images, batch_targets, batch_reverse, batch_path
+def collate_fn(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
+    """
+    A collate function to handle batching of images and their corresponding targets.
 
+    Args:
+        batch (list of tuples): Each tuple contains:
+            - image (Tensor): The image tensor.
+            - labels (Tensor): The tensor of labels for the image.
 
-def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
+    Returns:
+        Tuple[Tensor, List[Tensor]]: A tuple containing:
+            - A tensor of batched images.
+            - A list of tensors, each corresponding to bboxes for each image in the batch.
+    """
+    batch_size = len(batch)
+    target_sizes = [item[1].size(0) for item in batch]
+    # TODO: Improve readability of these proccess
+    # TODO: remove maxBbox or reduce loss function memory usage
+    batch_targets = torch.zeros(batch_size, min(max(target_sizes), 100), 5)
+    batch_targets[:, :, 0] = -1
+    for idx, target_size in enumerate(target_sizes):
+        batch_targets[idx, : min(target_size, 100)] = batch[idx][1][:100]
+
+    batch_images, _, batch_reverse, batch_path = zip(*batch)
+    batch_images = torch.stack(batch_images)
+    batch_reverse = torch.stack(batch_reverse)
+
+    return batch_size, batch_images, batch_targets, batch_reverse, batch_path
+
+
+def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
     if task == "inference":
         return StreamDataLoader(data_cfg)
 
     if dataset_cfg.auto_download:
         prepare_dataset(dataset_cfg, task)
-
-    return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp)
+    dataset = YoloDataset(data_cfg, dataset_cfg, task)
+
+    return DataLoader(
+        dataset,
+        batch_size=data_cfg.batch_size,
+        num_workers=data_cfg.cpu_num,
+        pin_memory=data_cfg.pin_memory,
+        collate_fn=collate_fn,
+    )
 
 
 class StreamDataLoader:
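Note: the module-level `collate_fn` pads every image's labels into one `(batch, ≤100, 5)` tensor whose padding rows carry class -1; `YoloDataset.get_data` strips such rows again with `bboxes[:, 0] != -1`. A standalone sketch with toy labels:

import torch

targets = [torch.ones(1, 5), torch.ones(3, 5)]  # per-image rows of [class, x1, y1, x2, y2]
batch_targets = torch.zeros(len(targets), max(t.size(0) for t in targets), 5)
batch_targets[:, :, 0] = -1  # mark every row as padding first
for idx, t in enumerate(targets):
    batch_targets[idx, : t.size(0)] = t  # real rows overwrite the markers
# batch_targets[0] keeps one real row; its remaining rows still have class -1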
yolo/tools/dataset_preparation.py CHANGED
@@ -3,10 +3,10 @@ from pathlib import Path
 from typing import Optional
 
 import requests
-from loguru import logger
 from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
 
 from yolo.config.config import DatasetConfig
+from yolo.utils.logger import logger
 
 
 def download_file(url, destination: Path):
@@ -30,7 +30,7 @@ def download_file(url, destination: Path):
         for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
             file.write(data)
             progress.update(task, advance=len(data))
-    logger.info("✅ Download completed.")
+    logger.info(":white_check_mark: Download completed.")
 
 
 def unzip_file(source: Path, destination: Path):
@@ -71,7 +71,7 @@ def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
 
         final_place.mkdir(parents=True, exist_ok=True)
         if check_files(final_place, dataset_args.get("file_num")):
-            logger.info(f"✅ Dataset {dataset_type: <12} already verified.")
+            logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.")
             continue
 
         if not local_zip_path.exists():
yolo/tools/drawer.py CHANGED
@@ -3,12 +3,12 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
-from loguru import logger
 from PIL import Image, ImageDraw, ImageFont
 from torchvision.transforms.functional import to_pil_image
 
 from yolo.config.config import ModelConfig
 from yolo.model.yolo import YOLO
+from yolo.utils.logger import logger
 
 
 def draw_bboxes(
@@ -121,6 +121,6 @@ def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=Fal
             dot.edge(str(idx), str(jdx))
     try:
         dot.render("Model-arch", format="png", cleanup=True)
-        logger.info("🎨 Drawing Model Architecture at Model-arch.png")
+        logger.info(":artist_palette: Drawing Model Architecture at Model-arch.png")
     except:
-        logger.warning("⚠️ Could not find graphviz backend, continue without drawing the model architecture")
+        logger.warning(":warning: Could not find graphviz backend, continue without drawing the model architecture")
yolo/tools/loss_functions.py CHANGED
@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Tuple
 
 import torch
 import torch.nn.functional as F
-from loguru import logger
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss
 
 from yolo.config.config import Config, LossConfig
 from yolo.utils.bounding_box_utils import BoxMatcher, Vec2Box, calculate_iou
+from yolo.utils.logger import logger
 
 
 class BCELoss(nn.Module):
@@ -124,17 +124,19 @@ class DualLoss:
         aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
         main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
 
+        total_loss = [
+            self.iou_rate * (aux_iou * self.aux_rate + main_iou),
+            self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
+            self.cls_rate * (aux_cls * self.aux_rate + main_cls),
+        ]
         loss_dict = {
-            "BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
-            "DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
-            "BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
+            f"Loss/{name}Loss": value.detach().item() for name, value in zip(["Box", "DFL", "BCE"], total_loss)
         }
-        loss_sum = sum(list(loss_dict.values())) / len(loss_dict)
-        return loss_sum, loss_dict
+        return sum(total_loss), loss_dict
 
 
 def create_loss_function(cfg: Config, vec2box) -> DualLoss:
     # TODO: make it flexible, if cfg doesn't contain aux, only use SingleLoss
     loss_function = DualLoss(cfg, vec2box)
-    logger.info("✅ Success load loss function")
+    logger.info(":white_check_mark: Success load loss function")
     return loss_function
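Note: `DualLoss.__call__` now returns the sum of the three weighted terms (the old code returned their mean) together with a dict of detached Python floats keyed for the logger. Shape of the result, with made-up numbers:

loss, loss_dict = loss_fn(aux_predicts, main_predicts, targets)  # hypothetical call
# loss      -> scalar tensor with grad_fn, e.g. tensor(3.21, grad_fn=<AddBackward0>)
# loss_dict -> {"Loss/BoxLoss": 1.10, "Loss/DFLLoss": 0.84, "Loss/BCELoss": 1.27}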
yolo/tools/solver.py CHANGED
@@ -1,267 +1,154 @@
-import contextlib
-import io
-import json
-import os
 import time
-from collections import defaultdict
 from pathlib import Path
-from typing import Dict, Optional
 
-import torch
-from loguru import logger
-from pycocotools.coco import COCO
-from torch import Tensor, distributed
-from torch.cuda.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
+import cv2
+import numpy as np
+from lightning import LightningModule
+from torchmetrics.detection import MeanAveragePrecision
 
-from yolo.config.config import Config, DatasetConfig, TrainConfig, ValidationConfig
-from yolo.model.yolo import YOLO
-from yolo.tools.data_loader import StreamDataLoader, create_dataloader
-from yolo.tools.drawer import draw_bboxes, draw_model
+from yolo.config.config import Config
+from yolo.model.yolo import create_model
+from yolo.tools.data_loader import create_dataloader
+from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import create_loss_function
-from yolo.utils.bounding_box_utils import Vec2Box, calculate_map
-from yolo.utils.dataset_utils import locate_label_paths
-from yolo.utils.logging_utils import ProgressLogger, log_model_structure
-from yolo.utils.model_utils import (
-    ExponentialMovingAverage,
-    PostProccess,
-    collect_prediction,
-    create_optimizer,
-    create_scheduler,
-    predicts_to_json,
-)
-from yolo.utils.solver_utils import calculate_ap
+from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
+from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
 
 
-class ModelTrainer:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
-        train_cfg: TrainConfig = cfg.task
-        self.model = model if not use_ddp else DDP(model, device_ids=[device])
-        self.use_ddp = use_ddp
-        self.vec2box = vec2box
-        self.device = device
-        self.optimizer = create_optimizer(model, train_cfg.optimizer)
-        self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
-        self.loss_fn = create_loss_function(cfg, vec2box)
-        self.progress = progress
-        self.num_epochs = cfg.task.epoch
-        self.mAPs_dict = defaultdict(list)
-
-        self.weights_dir = self.progress.save_path / "weights"
-        self.weights_dir.mkdir(exist_ok=True)
-
-        if not progress.quite_mode:
-            log_model_structure(model.model)
-            draw_model(model=model)
-
-        self.validation_dataloader = create_dataloader(
-            cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
-        )
-        self.validator = ModelValidator(cfg.task.validation, cfg.dataset, model, vec2box, progress, device)
-
-        if getattr(train_cfg.ema, "enabled", False):
-            self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
+class BaseModel(LightningModule):
+    def __init__(self, cfg: Config):
+        super().__init__()
+        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class ValidateModel(BaseModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        if self.cfg.task.task == "validation":
+            self.validation_cfg = self.cfg.task
         else:
-            self.ema = None
-        self.scaler = GradScaler()
-
-    def train_one_batch(self, images: Tensor, targets: Tensor):
-        images, targets = images.to(self.device), targets.to(self.device)
-        self.optimizer.zero_grad()
-
-        with autocast():
-            predicts = self.model(images)
-            aux_predicts = self.vec2box(predicts["AUX"])
-            main_predicts = self.vec2box(predicts["Main"])
-            loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
-
-        self.scaler.scale(loss).backward()
-        self.scaler.unscale_(self.optimizer)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-
-        return loss_item
-
-    def train_one_epoch(self, dataloader):
-        self.model.train()
-        total_loss = defaultdict(float)
-        total_samples = 0
-        self.optimizer.next_epoch(len(dataloader))
-        for batch_size, images, targets, *_ in dataloader:
-            self.optimizer.next_batch()
-            loss_each = self.train_one_batch(images, targets)
-
-            for loss_name, loss_val in loss_each.items():
-                if self.use_ddp:  # collecting loss for each batch
-                    distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG)
-                total_loss[loss_name] += loss_val.item() * batch_size
-            total_samples += batch_size
-            self.progress.one_batch(loss_each)
-
-        for loss_val in total_loss.values():
-            loss_val /= total_samples
-
-        if self.scheduler:
-            self.scheduler.step()
-
-        return total_loss
-
-    def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
-        file_name = file_name or f"E{epoch_idx:03d}.pt"
-        file_path = self.weights_dir / file_name
-
-        checkpoint = {
-            "epoch": epoch_idx,
-            "model_state_dict": self.model.state_dict(),
-            "optimizer_state_dict": self.optimizer.state_dict(),
-        }
-        if self.ema:
-            self.ema.apply_shadow()
-            checkpoint["model_state_dict_ema"] = self.model.state_dict()
-            self.ema.restore()
-
-        logger.info(f"💾 success save at {file_path}")
-        torch.save(checkpoint, file_path)
-
-    def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
-        save_flag = True
-        for mAP_key, mAP_val in mAPs.items():
-            self.mAPs_dict[mAP_key].append(mAP_val)
-            if mAP_val < max(self.mAPs_dict[mAP_key]):
-                save_flag = False
-        return save_flag
-
-    def solve(self, dataloader: DataLoader):
-        logger.info("🚄 Start Training!")
-        num_epochs = self.num_epochs
-
-        self.progress.start_train(num_epochs)
-        for epoch_idx in range(num_epochs):
-            if self.use_ddp:
-                dataloader.sampler.set_epoch(epoch_idx)
-
-            self.progress.start_one_epoch(len(dataloader), "Train", self.optimizer, epoch_idx)
-            epoch_loss = self.train_one_epoch(dataloader)
-            self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
-
-            mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
-            if mAPs is not None and self.good_epoch(mAPs):
-                self.save_checkpoint(epoch_idx=epoch_idx)
-            # TODO: save model if result are better than before
-        self.progress.finish_train()
-
-
-class ModelTester:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
-        self.model = model
-        self.device = device
-        self.progress = progress
-
-        self.post_proccess = PostProccess(vec2box, cfg.task.nms)
-        self.save_path = progress.save_path / "images"
-        os.makedirs(self.save_path, exist_ok=True)
-        self.save_predict = getattr(cfg.task, "save_predict", None)
-        self.idx2label = cfg.dataset.class_list
-
-    def solve(self, dataloader: StreamDataLoader):
-        logger.info("👀 Start Inference!")
-        if isinstance(self.model, torch.nn.Module):
-            self.model.eval()
-
-        if dataloader.is_stream:
-            import cv2
-            import numpy as np
-
-            last_time = time.time()
-        try:
-            for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
-                images = images.to(self.device)
-                rev_tensor = rev_tensor.to(self.device)
-                with torch.no_grad():
-                    predicts = self.model(images)
-                    predicts = self.post_proccess(predicts, rev_tensor)
-                img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
-
-                if dataloader.is_stream:
-                    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-                    fps = 1 / (time.time() - last_time)
-                    cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
-                    last_time = time.time()
-                    cv2.imshow("Prediction", img)
-                    if cv2.waitKey(1) & 0xFF == ord("q"):
-                        break
-                if not self.save_predict:
-                    continue
-                if self.save_predict != False:
-                    save_image_path = self.save_path / f"frame{idx:03d}.png"
-                    img.save(save_image_path)
-                    logger.info(f"💾 Saved visualize image at {save_image_path}")
-
-        except (KeyboardInterrupt, Exception) as e:
-            dataloader.stop_event.set()
-            dataloader.stop()
-            if isinstance(e, KeyboardInterrupt):
-                logger.error("User Keyboard Interrupt")
-            else:
-                raise e
-        dataloader.stop()
-
-
-class ModelValidator:
-    def __init__(
-        self,
-        validation_cfg: ValidationConfig,
-        dataset_cfg: DatasetConfig,
-        model: YOLO,
-        vec2box: Vec2Box,
-        progress: ProgressLogger,
-        device,
-    ):
-        self.model = model
-        self.device = device
-        self.progress = progress
-
-        self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
-        self.json_path = self.progress.save_path / "predict.json"
-
-        with contextlib.redirect_stdout(io.StringIO()):
-            # TODO: load with config file
-            json_path, _ = locate_label_paths(Path(dataset_cfg.path), dataset_cfg.get("validation", "val"))
-            if json_path:
-                self.coco_gt = COCO(json_path)
-
-    def solve(self, dataloader, epoch_idx=1):
-        # logger.info("🧪 Start Validation!")
-        self.model.eval()
-        predict_json, mAPs = [], defaultdict(list)
-        self.progress.start_one_epoch(len(dataloader), task="Validate")
-        for batch_size, images, targets, rev_tensor, img_paths in dataloader:
-            images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
-            with torch.no_grad():
-                predicts = self.model(images)
-                predicts = self.post_proccess(predicts)
-                for idx, predict in enumerate(predicts):
-                    mAP = calculate_map(predict, targets[idx])
-                    for mAP_key, mAP_val in mAP.items():
-                        mAPs[mAP_key].append(mAP_val)
-
-            avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
-            self.progress.one_batch(avg_mAPs)
-
-            predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
-        self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
-        self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
-
-        with open(self.json_path, "w") as f:
-            predict_json = collect_prediction(predict_json, self.progress.local_rank)
-            if self.progress.local_rank != 0:
-                return
-            json.dump(predict_json, f)
-        if hasattr(self, "coco_gt"):
-            self.progress.start_pycocotools()
-            result = calculate_ap(self.coco_gt, predict_json)
-            self.progress.finish_pycocotools(result, epoch_idx)
-
-        return avg_mAPs
+            self.validation_cfg = self.cfg.task.validation
+        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
+        self.metric.warn_on_many_detections = False
+        self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+
+    def setup(self, stage):
+        self.vec2box = create_converter(
+            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+        )
+        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
+
+    def val_dataloader(self):
+        return self.val_loader
+
+    def validation_step(self, batch, batch_idx):
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        predicts = self.post_proccess(self(images))
+        batch_metrics = self.metric(
+            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
+        )
+
+        self.log_dict(
+            {
+                "map": batch_metrics["map"],
+                "map_50": batch_metrics["map_50"],
+            },
+            on_step=True,
+            batch_size=batch_size,
+        )
+        return predicts
+
+    def on_validation_epoch_end(self):
+        epoch_metrics = self.metric.compute()
+        del epoch_metrics["classes"]
+        self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
+        self.log_dict(
+            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, rank_zero_only=True
+        )
+        self.metric.reset()
+
+
+class TrainModel(ValidateModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
+
+    def setup(self, stage):
+        super().setup(stage)
+        self.loss_fn = create_loss_function(self.cfg, self.vec2box)
+
+    def train_dataloader(self):
+        return self.train_loader
+
+    def on_train_epoch_start(self):
+        self.trainer.optimizers[0].next_epoch(len(self.train_loader))
+
+    def training_step(self, batch, batch_idx):
+        lr_dict = self.trainer.optimizers[0].next_batch()
+        batch_size, images, targets, *_ = batch
+        predicts = self(images)
+        aux_predicts = self.vec2box(predicts["AUX"])
+        main_predicts = self.vec2box(predicts["Main"])
+        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
+        self.log_dict(
+            loss_item,
+            prog_bar=True,
+            on_epoch=True,
+            batch_size=batch_size,
+            rank_zero_only=True,
+        )
+        self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
+        return loss * batch_size
+
+    def configure_optimizers(self):
+        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
+        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
+        return [optimizer], [scheduler]
+
+
+class InferenceModel(BaseModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        self._last_frame_time = time.time()  # frame timer for the FPS overlay
+        # TODO: Add FastModel
+        self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
+
+    def setup(self, stage):
+        self.vec2box = create_converter(
+            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+        )
+        self.post_process = PostProccess(self.vec2box, self.cfg.task.nms)
+
+    def predict_dataloader(self):
+        return self.predict_loader
+
+    def predict_step(self, batch, batch_idx):
+        images, rev_tensor, origin_frame = batch
+        predicts = self.post_process(self(images), rev_tensor)
+        img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
+        if getattr(self.predict_loader, "is_stream", None):
+            fps = self._display_stream(img)
+        else:
+            fps = None
+        if getattr(self.cfg.task, "save_predict", None):
+            self._save_image(img, batch_idx)
+        return img, fps
+
+    def _display_stream(self, img):
+        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        fps = 1 / (time.time() - self._last_frame_time)
+        self._last_frame_time = time.time()
+        cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
+        cv2.imshow("Prediction", img)
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            self.trainer.should_stop = True
+        return fps
+
+    def _save_image(self, img, batch_idx):
+        save_image_path = Path(self.logger.save_dir) / f"frame{batch_idx:03d}.png"
+        img.save(save_image_path)
+        print(f"💾 Saved visualize image at {save_image_path}")
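Note: the LightningModule hooks above replace the hand-rolled loops: `Trainer.fit(TrainModel(cfg))` drives `configure_optimizers`, `train_dataloader`, `on_train_epoch_start`, and `training_step`; `Trainer.validate` drives `val_dataloader`, `validation_step`, and `on_validation_epoch_end`; `Trainer.predict` drives `predict_dataloader` and `predict_step`. Mixed precision, gradient clipping, and distributed training, previously implemented by hand with `GradScaler`, `clip_grad_norm_`, and `DistributedDataParallel`, are now delegated to the `Trainer` flags set in `yolo/lazy.py` (`precision="16-mixed"`, `gradient_clip_val=10`).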
yolo/utils/bounding_box_utils.py CHANGED
@@ -4,17 +4,17 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from loguru import logger
 from torch import Tensor, arange, tensor
 from torchvision.ops import batched_nms
 
 from yolo.config.config import AnchorConfig, MatcherConfig, ModelConfig, NMSConfig
 from yolo.model.yolo import YOLO
+from yolo.utils.logger import logger
 
 
 def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
     metrics = metrics.lower()
-    EPS = 1e-9
+    EPS = 1e-7
     dtype = bbox1.dtype
     bbox1 = bbox1.to(torch.float32)
     bbox2 = bbox2.to(torch.float32)
@@ -69,7 +69,8 @@ def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
             (bbox2[..., 2] - bbox2[..., 0]) / (bbox2[..., 3] - bbox2[..., 1] + EPS)
         )
         v = (4 / (math.pi**2)) * (arctan**2)
-        alpha = v / (v - iou + 1 + EPS)
+        with torch.no_grad():
+            alpha = v / (v - iou + 1 + EPS)
         # Compute CIoU
         ciou = diou - alpha * v
         return ciou.to(dtype)
@@ -129,7 +130,10 @@ def generate_anchors(image_size: List[int], strides: List[int]):
         shift = stride // 2
         h = torch.arange(0, H, stride) + shift
         w = torch.arange(0, W, stride) + shift
-        anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
+        if torch.__version__ >= "2.3.0":
+            anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
+        else:
+            anchor_h, anchor_w = torch.meshgrid(h, w)
         anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
         anchors.append(anchor)
     all_anchors = torch.cat(anchors, dim=0)
@@ -207,7 +211,7 @@ class BoxMatcher:
         topk_masks = topk_targets > 0
         return topk_targets, topk_masks
 
-    def filter_duplicates(self, target_matrix: Tensor):
+    def filter_duplicates(self, target_matrix: Tensor, topk_mask: Tensor):
         """
         Filter the maximum suitability target index of each anchor.
 
@@ -217,9 +221,11 @@ class BoxMatcher:
         Returns:
             unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
         """
-        # TODO: add a assert for no target on the image
-        unique_indices = target_matrix.argmax(dim=1)
-        return unique_indices[..., None]
+        duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
+        max_idx = F.one_hot(target_matrix.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
+        topk_mask = torch.where(duplicates, max_idx, topk_mask)
+        unique_indices = topk_mask.argmax(dim=1)
+        return unique_indices[..., None], topk_mask.sum(1), topk_mask
 
     def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
         """Matches each target to the most suitable anchor.
@@ -271,16 +277,15 @@ class BoxMatcher:
         topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
 
         # delete one anchor pred assign to mutliple gts
-        unique_indices = self.filter_duplicates(topk_targets)
-
-        # TODO: do we need grid_mask? Filter the valid groud truth
-        valid_mask = (grid_mask.sum(dim=-2) * topk_mask.sum(dim=-2)).bool()
+        unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
 
         align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
         align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
         align_cls = F.one_hot(align_cls, self.class_num)
 
         # normalize class ditribution
+        iou_mat *= topk_mask
+        target_matrix *= topk_mask
         max_target = target_matrix.amax(dim=-1, keepdim=True)
         max_iou = iou_mat.amax(dim=-1, keepdim=True)
         normalize_term = (target_matrix / (max_target + 1e-9)) * max_iou
@@ -295,7 +300,7 @@ class Vec2Box:
         self.device = device
 
         if hasattr(anchor_cfg, "strides"):
-            logger.info(f"🈶 Found stride of model {anchor_cfg.strides}")
+            logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
            self.strides = anchor_cfg.strides
         else:
             logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
@@ -339,7 +344,7 @@ class Anc2Box:
         self.device = device
 
         if hasattr(anchor_cfg, "strides"):
-            logger.info(f"🈶 Found stride of model {anchor_cfg.strides}")
+            logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
            self.strides = anchor_cfg.strides
         else:
             logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
@@ -413,7 +418,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt
     valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
 
     batch_idx, *_ = torch.where(valid_mask)
-    nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
+    nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
     predicts_nms = []
     for idx in range(cls_dist.size(0)):
         instance_idx = nms_idx[idx == batch_idx[nms_idx]]
@@ -471,3 +476,10 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
         "mAP.5:.95": torch.mean(torch.stack(aps)),
     }
     return mAP
+
+
+def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
+    bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
+    if prediction.size(1) == 6:
+        bbox["scores"] = prediction[:, 5]
+    return bbox
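Note: `to_metrics_format` adapts one image's tensor to the dict layout `MeanAveragePrecision` expects. A post-NMS prediction row is `[class, x1, y1, x2, y2, confidence]`, while a padded target row has only five columns, so `scores` is attached only when the sixth column exists. A quick illustration with a made-up row:

import torch

pred = torch.tensor([[0.0, 0.0, 0.0, 160.0, 120.0, 0.6682]])
to_metrics_format(pred)
# {'boxes': tensor([[  0.,   0., 160., 120.]]),
#  'labels': tensor([0], dtype=torch.int32),
#  'scores': tensor([0.6682])}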
yolo/utils/dataset_utils.py CHANGED
@@ -5,9 +5,10 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
-from loguru import logger
+import torch
 
 from yolo.tools.data_conversion import discretize_categories
+from yolo.utils.logger import logger
 
 
 def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path]:
@@ -111,3 +112,16 @@ def scale_segmentation(
         seg_array_with_cat.append(scaled_flat_seg_data)
 
     return seg_array_with_cat
+
+
+def tensorlize(data):
+    img_paths, bboxes = zip(*data)
+    max_box = max(bbox.size(0) for bbox in bboxes)
+    padded_bbox_list = []
+    for bbox in bboxes:
+        padding = torch.full((max_box, 5), -1, dtype=torch.float32)
+        padding[: bbox.size(0)] = bbox
+        padded_bbox_list.append(padding)
+    bboxes = np.stack(padded_bbox_list)
+    img_paths = np.array(img_paths)
+    return img_paths, bboxes
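Note: `tensorlize` turns the ragged list of `(path, bboxes)` pairs into two dense arrays by padding every label tensor to the longest one with rows of -1; the padding is undone later by the `bboxes[:, 0] != -1` mask in `YoloDataset.get_data`. A standalone round trip with toy data:

import torch
from yolo.utils.dataset_utils import tensorlize

data = [("a.jpg", torch.ones(2, 5)), ("b.jpg", torch.ones(3, 5))]
img_paths, bboxes = tensorlize(data)
# bboxes.shape == (2, 3, 5); bboxes[0, 2] is a padding row of -1s
valid = bboxes[0][bboxes[0][:, 0] != -1]  # recovers the 2 real rows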
yolo/utils/deploy_utils.py CHANGED
@@ -1,11 +1,11 @@
1
  from pathlib import Path
2
 
3
  import torch
4
- from loguru import logger
5
  from torch import Tensor
6
 
7
  from yolo.config.config import Config
8
  from yolo.model.yolo import create_model
 
9
 
10
 
11
  class FastModelLoader:
@@ -21,10 +21,10 @@ class FastModelLoader:
21
 
22
  def _validate_compiler(self):
23
  if self.compiler not in ["onnx", "trt", "deploy"]:
24
- logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
25
  self.compiler = None
26
  if self.cfg.device == "mps" and self.compiler == "trt":
27
- logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
28
  self.compiler = None
29
 
30
  def load_model(self, device):
@@ -59,7 +59,7 @@ class FastModelLoader:
59
  providers = ["CUDAExecutionProvider"]
60
  try:
61
  ort_session = InferenceSession(self.model_path, providers=providers)
62
- logger.info("πŸš€ Using ONNX as MODEL frameworks!")
63
  except Exception as e:
64
  logger.warning(f"🈳 Error loading ONNX model: {e}")
65
  ort_session = self._create_onnx_model(providers)
@@ -79,7 +79,7 @@ class FastModelLoader:
79
  output_names=["output"],
80
  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
81
  )
82
- logger.info(f"πŸ“₯ ONNX model saved to {self.model_path}")
83
  return InferenceSession(self.model_path, providers=providers)
84
 
85
  def _load_trt_model(self):
@@ -88,7 +88,7 @@ class FastModelLoader:
88
  try:
89
  model_trt = TRTModule()
90
  model_trt.load_state_dict(torch.load(self.model_path))
91
- logger.info("πŸš€ Using TensorRT as MODEL frameworks!")
92
  except FileNotFoundError:
93
  logger.warning(f"🈳 No found model weight at {self.model_path}")
94
  model_trt = self._create_trt_model()
@@ -102,5 +102,5 @@ class FastModelLoader:
102
  logger.info(f"♻️ Creating TensorRT model")
103
  model_trt = torch2trt(model.cuda(), [dummy_input])
104
  torch.save(model_trt.state_dict(), self.model_path)
105
- logger.info(f"πŸ“₯ TensorRT model saved to {self.model_path}")
106
  return model_trt
 
1
  from pathlib import Path
2
 
3
  import torch
 
4
  from torch import Tensor
5
 
6
  from yolo.config.config import Config
7
  from yolo.model.yolo import create_model
8
+ from yolo.utils.logger import logger
9
 
10
 
11
  class FastModelLoader:
 
21
 
22
  def _validate_compiler(self):
23
  if self.compiler not in ["onnx", "trt", "deploy"]:
24
+ logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
25
  self.compiler = None
26
  if self.cfg.device == "mps" and self.compiler == "trt":
27
+ logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
28
  self.compiler = None
29
 
30
  def load_model(self, device):
 
59
  providers = ["CUDAExecutionProvider"]
60
  try:
61
  ort_session = InferenceSession(self.model_path, providers=providers)
62
+ logger.info(":rocket: Using ONNX as MODEL frameworks!")
63
  except Exception as e:
64
  logger.warning(f"🈳 Error loading ONNX model: {e}")
65
  ort_session = self._create_onnx_model(providers)
 
79
  output_names=["output"],
80
  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
81
  )
82
+ logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
83
  return InferenceSession(self.model_path, providers=providers)
84
 
85
  def _load_trt_model(self):
 
88
  try:
89
  model_trt = TRTModule()
90
  model_trt.load_state_dict(torch.load(self.model_path))
91
+ logger.info(":rocket: Using TensorRT as MODEL frameworks!")
92
  except FileNotFoundError:
93
  logger.warning(f"🈳 No found model weight at {self.model_path}")
94
  model_trt = self._create_trt_model()
 
102
  logger.info(f"♻️ Creating TensorRT model")
103
  model_trt = torch2trt(model.cuda(), [dummy_input])
104
  torch.save(model_trt.state_dict(), self.model_path)
105
+ logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
106
  return model_trt
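The ONNX path above hinges on the `dynamic_axes` argument; a standalone sketch of the same export pattern, with a toy module standing in for the YOLO model and an arbitrary file name:

```python
import torch

model = torch.nn.Conv2d(3, 16, 3)          # placeholder for the real model
dummy_input = torch.zeros(1, 3, 640, 640)  # trace with batch size 1
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    # keep the batch dimension symbolic so one file serves any batch size
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
```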
yolo/utils/logger.py ADDED
@@ -0,0 +1,11 @@
1
+ import logging
2
+
3
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only
4
+ from rich.console import Console
5
+ from rich.logging import RichHandler
6
+
7
+ logger = logging.getLogger("yolo")
8
+ logger.setLevel(logging.DEBUG)
9
+ logger.propagate = False
10
+ if rank_zero_only.rank == 0 and not logger.hasHandlers():
11
+ logger.addHandler(RichHandler(console=Console(), show_level=True, show_path=True, show_time=True, markup=True))
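Usage is a plain import; because the `RichHandler` is only attached on rank 0, other ranks stay quiet by construction. A minimal sketch:

```python
from yolo.utils.logger import logger

# markup=True on the handler, so Rich tags and emoji shortcodes render.
logger.info(":rocket: Training started on [bold]rank 0[/]")
logger.warning("Checkpoint not found, starting from scratch")
```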
yolo/utils/logging_utils.py CHANGED
@@ -11,55 +11,39 @@ Example:
11
  custom_logger()
12
  """
13
 
14
- import os
15
- import random
16
- import sys
17
  from collections import deque
 
18
  from pathlib import Path
19
  from typing import Any, Dict, List, Optional, Tuple, Union
20
 
21
  import numpy as np
22
  import torch
23
  import wandb
24
- import wandb.errors.term
25
- from loguru import logger
26
  from omegaconf import ListConfig
 
27
  from rich.console import Console, Group
28
- from rich.progress import (
29
- BarColumn,
30
- Progress,
31
- SpinnerColumn,
32
- TextColumn,
33
- TimeRemainingColumn,
34
- )
35
  from rich.table import Table
 
36
  from torch import Tensor
37
  from torch.nn import ModuleList
38
- from torch.optim import Optimizer
39
- from torchvision.transforms.functional import pil_to_tensor
40
 
41
  from yolo.config.config import Config, YOLOLayer
42
  from yolo.model.yolo import YOLO
43
- from yolo.tools.drawer import draw_bboxes
44
  from yolo.utils.solver_utils import make_ap_table
45
 
46
 
47
- def custom_logger(quite: bool = False):
48
- logger.remove()
49
- if quite:
50
- return
51
- logger.add(
52
- sys.stderr,
53
- colorize=True,
54
- format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
55
- )
56
-
57
-
58
  # TODO: should be moved to the correct position
59
  def set_seed(seed):
60
- random.seed(seed)
61
- np.random.seed(seed)
62
- torch.manual_seed(seed)
63
  if torch.cuda.is_available():
64
  torch.cuda.manual_seed(seed)
65
  torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
@@ -67,189 +51,220 @@ def set_seed(seed):
67
  torch.backends.cudnn.benchmark = False
68
 
69
 
70
- class ProgressLogger(Progress):
71
- def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
72
- set_seed(cfg.lucky_number)
73
- self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
74
- self.quite_mode = self.local_rank or getattr(cfg, "quite", False)
75
- custom_logger(self.quite_mode)
76
- self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
77
-
78
- progress_bar = (
79
- SpinnerColumn(),
80
- TextColumn("[progress.description]{task.description}"),
81
- BarColumn(bar_width=None),
82
- TextColumn("{task.completed:.0f}/{task.total:.0f}"),
83
- TimeRemainingColumn(),
84
- )
85
- self.ap_table = Table()
86
- # TODO: load maxlen by config files
87
- self.ap_past_list = deque(maxlen=5)
88
- self.last_result = 0
89
- super().__init__(*args, *progress_bar, **kwargs)
90
-
91
- self.use_wandb = cfg.use_wandb
92
- if self.use_wandb and self.local_rank == 0:
93
- wandb.errors.term._log = custom_wandb_log
94
- self.wandb = wandb.init(
95
- project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
96
- )
97
 
98
- self.use_tensorboard = cfg.use_tensorboard
99
- if self.use_tensorboard and self.local_rank == 0:
100
- from torch.utils.tensorboard import SummaryWriter
101
 
102
- self.tb_writer = SummaryWriter(log_dir=self.save_path / "tensorboard")
103
- logger.opt(colors=True).info(f"πŸ“ Enable TensorBoard locally at <blue><u>http://localhost:6006</></>")
104
 
105
- def rank_check(logging_function):
106
- def wrapper(self, *args, **kwargs):
107
- if getattr(self, "local_rank", 0) != 0:
108
- return
109
- return logging_function(self, *args, **kwargs)
110
 
111
- return wrapper
 
112
 
113
- def get_renderable(self):
114
- renderable = Group(*self.get_renderables(), self.ap_table)
115
- return renderable
 
 
116
 
117
- @rank_check
118
- def start_train(self, num_epochs: int):
119
- self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
120
- self.update(self.task_epoch, advance=-0.5)
121
-
122
- @rank_check
123
- def start_one_epoch(
124
- self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
125
- ):
126
- self.num_batches = num_batches
127
- self.task = task
128
- if hasattr(self, "task_epoch"):
129
- self.update(self.task_epoch, description=f"[cyan] Preparing Data")
130
-
131
- if optimizer is not None:
132
- lr_values = [params["lr"] for params in optimizer.param_groups]
133
- lr_names = ["Learning Rate/bias", "Learning Rate/norm", "Learning Rate/conv"]
134
- if self.use_wandb:
135
- for lr_name, lr_value in zip(lr_names, lr_values):
136
- self.wandb.log({lr_name: lr_value}, step=epoch_idx)
137
-
138
- if self.use_tensorboard:
139
- for lr_name, lr_value in zip(lr_names, lr_values):
140
- self.tb_writer.add_scalar(lr_name, lr_value, global_step=epoch_idx)
141
-
142
- self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
143
-
144
- @rank_check
145
- def one_batch(self, batch_info: Dict[str, Tensor] = None):
146
- epoch_descript = "[cyan]" + self.task + "[white] |"
147
- batch_descript = "|"
148
- if self.task == "Train":
149
- self.update(self.task_epoch, advance=1 / self.num_batches)
150
- for info_name, info_val in batch_info.items():
151
- epoch_descript += f"{info_name: ^9}|"
152
- batch_descript += f" {info_val:2.2f} |"
153
- self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
154
- if hasattr(self, "task_epoch"):
155
- self.update(self.task_epoch, description=epoch_descript)
156
-
157
- @rank_check
158
- def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
159
- if self.task == "Train":
160
- prefix = "Loss"
161
- elif self.task == "Validate":
162
- prefix = "Metrics"
163
- batch_info = {f"{prefix}/{key}": value for key, value in batch_info.items()}
164
- if self.use_wandb:
165
- self.wandb.log(batch_info, step=epoch_idx)
166
- if self.use_tensorboard:
167
- for key, value in batch_info.items():
168
- self.tb_writer.add_scalar(key, value, epoch_idx)
169
-
170
- self.remove_task(self.batch_task)
171
-
172
- @rank_check
173
- def visualize_image(
174
- self,
175
- images: Optional[Tensor] = None,
176
- ground_truth: Optional[Tensor] = None,
177
- prediction: Optional[Union[List[Tensor], Tensor]] = None,
178
- epoch_idx: int = 0,
179
- ) -> None:
180
- """
181
- Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
182
-
183
- Args:
184
- images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
185
- ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
186
- prediction (prediction: Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6) or (N, 6). Defaults to None.
187
- epoch_idx (int): Current epoch index. Defaults to 0.
188
- """
189
- if images is not None:
190
- images = images[0] if images.ndim == 4 else images
191
- if self.use_wandb:
192
- wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
193
- if self.use_tensorboard:
194
- self.tb_writer.add_image("Media/Input Image", images, 1)
195
-
196
- if ground_truth is not None:
197
- gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
198
- if self.use_wandb:
199
- wandb.log(
200
- {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
201
- step=epoch_idx,
202
- )
203
- if self.use_tensorboard:
204
- self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
205
-
206
- if prediction is not None:
207
- pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
208
- if self.use_wandb:
209
- wandb.log(
210
- {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
211
- step=epoch_idx,
212
- )
213
- if self.use_tensorboard:
214
- self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
215
-
216
- @rank_check
217
- def start_pycocotools(self):
218
- self.batch_task = self.add_task("[green]Run pycocotools", total=1)
219
-
220
- @rank_check
221
- def finish_pycocotools(self, result, epoch_idx=-1):
222
- ap_table, ap_main = make_ap_table(result * 100, self.ap_past_list, self.last_result, epoch_idx)
223
- self.last_result = np.maximum(result, self.last_result)
224
- self.ap_past_list.append((epoch_idx, ap_main))
225
- self.ap_table = ap_table
226
-
227
- if self.use_wandb:
228
- self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
229
- if self.use_tensorboard:
230
- # TODO: waiting torch bugs fix, https://github.com/pytorch/pytorch/issues/32651
231
- self.tb_writer.add_scalar("PyCOCO/AP @ .5:.95", ap_main[2], epoch_idx)
232
- self.tb_writer.add_scalar("PyCOCO/AP @ .5", ap_main[5], epoch_idx)
233
-
234
- self.update(self.batch_task, advance=1)
235
- self.refresh()
236
- self.remove_task(self.batch_task)
237
 
238
- @rank_check
239
- def finish_train(self):
240
- self.remove_task(self.task_epoch)
241
- self.stop()
242
- if self.use_wandb:
243
- self.wandb.finish()
244
- if self.use_tensorboard:
245
- self.tb_writer.close()
246
 
247
 
248
- def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
249
- if silent:
250
- return
251
- for line in string.split("\n"):
252
- logger.opt(raw=not newline, colors=True).info("🌐 " + line)
253
 
254
 
255
  def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
@@ -279,6 +294,7 @@ def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
279
  console.print(table)
280
 
281
 
 
282
  def validate_log_directory(cfg: Config, exp_name: str) -> Path:
283
  base_path = Path(cfg.out_path, cfg.task.task)
284
  save_path = base_path / exp_name
@@ -296,8 +312,8 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
296
  )
297
 
298
  save_path.mkdir(parents=True, exist_ok=True)
299
- logger.opt(colors=True).info(f"πŸ“„ Created log folder: <u><fg #808080>{save_path}</></>")
300
- logger.add(save_path / "output.log", mode="w", backtrace=True, diagnose=True)
301
  return save_path
302
 
303
 
@@ -332,4 +348,4 @@ def log_bbox(
332
  bbox_entry["scores"] = {"confidence": conf[0]}
333
  bbox_list.append(bbox_entry)
334
 
335
- return bbox_list
 
11
  custom_logger()
12
  """
13
 
14
+ import logging
 
 
15
  from collections import deque
16
+ from logging import FileHandler
17
  from pathlib import Path
18
  from typing import Any, Dict, List, Optional, Tuple, Union
19
 
20
  import numpy as np
21
  import torch
22
  import wandb
23
+ from lightning import LightningModule, Trainer, seed_everything
24
+ from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
25
+ from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
26
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
27
+ from lightning.pytorch.utilities import rank_zero_only
28
  from omegaconf import ListConfig
29
+ from rich import get_console, reconfigure
30
  from rich.console import Console, Group
31
+ from rich.logging import RichHandler
32
  from rich.table import Table
33
+ from rich.text import Text
34
  from torch import Tensor
35
  from torch.nn import ModuleList
36
+ from typing_extensions import override
 
37
 
38
  from yolo.config.config import Config, YOLOLayer
39
  from yolo.model.yolo import YOLO
40
+ from yolo.utils.logger import logger
41
  from yolo.utils.solver_utils import make_ap_table
42
 
43
 
44
  # TODO: should be moved to the correct position
45
  def set_seed(seed):
46
+ seed_everything(seed)
 
 
47
  if torch.cuda.is_available():
48
  torch.cuda.manual_seed(seed)
49
  torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
 
51
  torch.backends.cudnn.benchmark = False
52
 
53
 
54
+ class YOLOCustomProgress(CustomProgress):
55
+ def get_renderable(self):
56
+ renderable = Group(*self.get_renderables())
57
+ if hasattr(self, "table"):
58
+ renderable = Group(*self.get_renderables(), self.table)
59
+ return renderable
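The override above is a general Rich pattern: append an extra renderable to the progress group so a table stays pinned under the bars. A self-contained sketch using only `rich`:

```python
from rich.console import Group
from rich.progress import Progress
from rich.table import Table

class TableProgress(Progress):
    def get_renderable(self):
        # Render the usual bars, plus a metrics table once one is attached.
        if hasattr(self, "table"):
            return Group(*self.get_renderables(), self.table)
        return Group(*self.get_renderables())

with TableProgress() as progress:
    task = progress.add_task("epoch", total=10)
    progress.table = Table("Epoch", "AP @ .5")
    progress.table.add_row("1", "52.1")
    progress.advance(task, 5)
```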
60
 
61
 
62
+ class YOLORichProgressBar(RichProgressBar):
63
+ @override
64
+ @rank_zero_only
65
+ def _init_progress(self, trainer: "Trainer") -> None:
66
+ if self.is_enabled and (self.progress is None or self._progress_stopped):
67
+ self._reset_progress_bar_ids()
68
+ reconfigure(**self._console_kwargs)
69
+ self._console = Console()
70
+ self._console.clear_live()
71
+ self.progress = YOLOCustomProgress(
72
+ *self.configure_columns(trainer),
73
+ auto_refresh=False,
74
+ disable=self.is_disabled,
75
+ console=self._console,
76
+ )
77
+ self.progress.start()
78
+
79
+ self._progress_stopped = False
80
+
81
+ self.max_result = 0
82
+ self.past_results = deque(maxlen=5)
83
+ self.progress.table = Table()
84
+
85
+ @override
86
+ def _get_train_description(self, current_epoch: int) -> str:
87
+ return Text("[cyan]Train [white]|")
88
+
89
+ @override
90
+ @rank_zero_only
91
+ def on_train_start(self, trainer, pl_module):
92
+ self._init_progress(trainer)
93
+ num_epochs = trainer.max_epochs - 1
94
+ self.task_epoch = self._add_task(
95
+ total_batches=num_epochs,
96
+ description=f"[cyan]Start Training {num_epochs} epochs",
97
+ )
98
+ self.max_result = 0
99
+ self.past_results.clear()
100
+ self.progress.update(self.task_epoch, advance=-0.5)
101
+
102
+ @override
103
+ @rank_zero_only
104
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
105
+ self._update(self.train_progress_bar_id, batch_idx + 1)
106
+ self._update_metrics(trainer, pl_module)
107
+ epoch_descript = "[cyan]Train [white]|"
108
+ batch_descript = "[green]Train [white]|"
109
+ metrics = self.get_metrics(trainer, pl_module)
110
+ metrics.pop("v_num")
111
+ for metrics_name, metrics_val in metrics.items():
112
+ if "Loss_step" in metrics_name:
113
+ epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
114
+ batch_descript += f" {metrics_val:2.2f} |"
115
+
116
+ self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
117
+ self.progress.update(self.train_progress_bar_id, description=batch_descript)
118
+ self.refresh()
119
 
120
+ @override
121
+ @rank_zero_only
122
+ def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
123
+ self._update_metrics(trainer, pl_module)
124
+ self.progress.remove_task(self.train_progress_bar_id)
125
+ self.train_progress_bar_id = None
126
+
127
+ @override
128
+ @rank_zero_only
129
+ def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
130
+ if trainer.state.fn == "fit":
131
+ self._update_metrics(trainer, pl_module)
132
+ self.reset_dataloader_idx_tracker()
133
+ all_metrics = self.get_metrics(trainer, pl_module)
134
+
135
+ ap_ar_list = [
136
+ key
137
+ for key in all_metrics.keys()
138
+ if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
139
+ ]
140
+ score = np.array([all_metrics[key] for key in ap_ar_list]) * 100
141
+
142
+ self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
143
+ self.max_result = np.maximum(score, self.max_result)
144
+ self.past_results.append((trainer.current_epoch, ap_main))
145
+
146
+ @override
147
+ def refresh(self) -> None:
148
+ if self.progress:
149
+ self.progress.refresh()
150
+
151
+ @property
152
+ def validation_description(self) -> str:
153
+ return "[green]Validation"
154
+
155
+
156
+ class YOLORichModelSummary(RichModelSummary):
157
+ @staticmethod
158
+ @override
159
+ def summarize(
160
+ summary_data: List[Tuple[str, List[str]]],
161
+ total_parameters: int,
162
+ trainable_parameters: int,
163
+ model_size: float,
164
+ total_training_modes: Dict[str, int],
165
+ **summarize_kwargs: Any,
166
+ ) -> None:
167
+ from lightning.pytorch.utilities.model_summary import get_human_readable_count
168
+
169
+ console = get_console()
170
+
171
+ header_style: str = summarize_kwargs.get("header_style", "bold magenta")
172
+ table = Table(header_style=header_style)
173
+ table.add_column(" ", style="dim")
174
+ table.add_column("Name", justify="left", no_wrap=True)
175
+ table.add_column("Type")
176
+ table.add_column("Params", justify="right")
177
+ table.add_column("Mode")
178
+
179
+ column_names = list(zip(*summary_data))[0]
180
+
181
+ for column_name in ["In sizes", "Out sizes"]:
182
+ if column_name in column_names:
183
+ table.add_column(column_name, justify="right", style="white")
184
+
185
+ rows = list(zip(*(arr[1] for arr in summary_data)))
186
+ for row in rows:
187
+ table.add_row(*row)
188
+
189
+ console.print(table)
190
+
191
+ parameters = []
192
+ for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
193
+ parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
194
+
195
+ grid = Table(header_style=header_style)
196
+ table.add_column(" ", style="dim")
197
+ grid.add_column("[bold]Attributes[/]")
198
+ grid.add_column("Value")
199
+
200
+ grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
201
+ grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
202
+ grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
203
+ grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
204
+ grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
205
+ grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")
206
+
207
+ console.print(grid)
208
+
209
+
210
+ class ImageLogger(Callback):
211
+ def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
212
+ if batch_idx != 0:
213
+ return
214
+ batch_size, images, targets, rev_tensor, img_paths = batch
215
+ gt_boxes = targets[0] if targets.ndim == 3 else targets
216
+ pred_boxes = outputs[0] if isinstance(outputs, list) else outputs
217
+ images = [images[0]]
218
+ step = trainer.current_epoch
219
+ for logger in trainer.loggers:
220
+ if isinstance(logger, WandbLogger):
221
+ logger.log_image("Input Image", images, step=step)
222
+ logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)])
223
+ logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)])
224
+
225
+
226
+ def setup_logger(logger_name):
227
+ class EmojiFormatter(logging.Formatter):
228
+ def format(self, record, emoji=":high_voltage:"):
229
+ return f"{emoji} {super().format(record)}"
230
+
231
+ rich_handler = RichHandler(markup=True)
232
+ rich_handler.setFormatter(EmojiFormatter("%(message)s"))
233
+ rich_logger = logging.getLogger(logger_name)
234
+ if rich_logger:
235
+ rich_logger.handlers.clear()
236
+ rich_logger.addHandler(rich_handler)
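To see the formatter in action, route one of the Lightning loggers through it (this assumes `setup_logger` as defined above; the message is invented):

```python
import logging

setup_logger("lightning.pytorch")
# Records from that logger now pass through RichHandler with the emoji prefix,
# rendering roughly as: ⚑ GPU available: True (cuda)
logging.getLogger("lightning.pytorch").warning("GPU available: True (cuda)")
```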
237
+
238
+
239
+ def setup(cfg: Config):
240
+ # seed_everything(cfg.lucky_number)
241
+ if hasattr(cfg, "quite"):
242
+ logger.removeHandler("YOLO_logger")
243
+ return
244
 
245
+ setup_logger("lightning.fabric")
246
+ setup_logger("lightning.pytorch")
247
 
248
+ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
249
+ if silent:
250
+ return
251
+ for line in string.split("\n"):
252
+ logger.info(Text.from_ansi(":globe_with_meridians: " + line))
253
 
254
+ wandb.errors.term._log = custom_wandb_log
 
255
 
256
+ save_path = validate_log_directory(cfg, cfg.name)
257
 
258
+ progress, loggers = [], []
259
+ progress.append(YOLORichProgressBar())
260
+ progress.append(YOLORichModelSummary())
261
+ progress.append(ImageLogger())
262
+ if cfg.use_tensorboard:
263
+ loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path))
264
+ if cfg.use_wandb:
265
+ loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))
266
 
267
+ return progress, loggers
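A sketch of wiring the returned callbacks and loggers into a Lightning `Trainer`; `cfg` and `model` are assumed to come from the application entrypoint, and `max_epochs=10` is a placeholder:

```python
from lightning import Trainer

callbacks, loggers = setup(cfg)  # cfg: the application Config, loaded elsewhere
trainer = Trainer(max_epochs=10, callbacks=callbacks, logger=loggers)
trainer.fit(model)  # model: the LightningModule under training
```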
268
 
269
 
270
  def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
 
294
  console.print(table)
295
 
296
 
297
+ @rank_zero_only
298
  def validate_log_directory(cfg: Config, exp_name: str) -> Path:
299
  base_path = Path(cfg.out_path, cfg.task.task)
300
  save_path = base_path / exp_name
 
312
  )
313
 
314
  save_path.mkdir(parents=True, exist_ok=True)
315
+ logger.info(f"πŸ“„ Created log folder: [blue b u]{save_path}[/]")
316
+ logger.addHandler(FileHandler(save_path / "output.log"))
317
  return save_path
318
 
319
 
 
348
  bbox_entry["scores"] = {"confidence": conf[0]}
349
  bbox_list.append(bbox_entry)
350
 
351
+ return {"predictions": {"box_data": bbox_list}}
yolo/utils/model_utils.py CHANGED
@@ -4,7 +4,6 @@ from typing import List, Optional, Type, Union
4
 
5
  import torch
6
  import torch.distributed as dist
7
- from loguru import logger
8
  from omegaconf import ListConfig
9
  from torch import Tensor
10
  from torch.optim import Optimizer
@@ -13,6 +12,7 @@ from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
13
  from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
14
  from yolo.model.yolo import YOLO
15
  from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
 
16
 
17
 
18
  class ExponentialMovingAverage:
@@ -52,9 +52,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
52
  conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
53
 
54
  model_parameters = [
55
- {"params": bias_params, "weight_decay": 0},
56
- {"params": conv_params},
57
- {"params": norm_params, "weight_decay": 0},
58
  ]
59
 
60
  def next_epoch(self, batch_num):
@@ -65,12 +65,16 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
65
 
66
  def next_batch(self):
67
  self.batch_idx += 1
 
68
  for lr_idx, param_group in enumerate(self.param_groups):
69
  min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
70
  param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
 
 
71
 
72
  optimizer_class.next_batch = next_batch
73
  optimizer_class.next_epoch = next_epoch
 
74
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
75
  optimizer.max_lr = [0.1, 0, 0]
76
  return optimizer
@@ -168,6 +172,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
168
  batch_json = []
169
  for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
170
  scale, shift = box_reverse.split([1, 4])
 
171
  bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
172
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
173
  for cls, *pos, conf in bboxes:
 
4
 
5
  import torch
6
  import torch.distributed as dist
 
7
  from omegaconf import ListConfig
8
  from torch import Tensor
9
  from torch.optim import Optimizer
 
12
  from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
13
  from yolo.model.yolo import YOLO
14
  from yolo.utils.bounding_box_utils import bbox_nms, transform_bbox
15
+ from yolo.utils.logger import logger
16
 
17
 
18
  class ExponentialMovingAverage:
 
52
  conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]
53
 
54
  model_parameters = [
55
+ {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
56
+ {"params": conv_params, "momentum": 0.8},
57
+ {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
58
  ]
59
 
60
  def next_epoch(self, batch_num):
 
65
 
66
  def next_batch(self):
67
  self.batch_idx += 1
68
+ lr_dict = dict()
69
  for lr_idx, param_group in enumerate(self.param_groups):
70
  min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
71
  param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
72
+ lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
73
+ return lr_dict
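The schedule in `next_batch` is a plain linear ramp from `min_lr` to `max_lr` over `batch_num` steps; a tiny standalone sketch with invented numbers:

```python
min_lr, max_lr, batch_num = 0.0, 0.1, 5

for batch_idx in range(1, batch_num + 1):
    lr = min_lr + batch_idx * (max_lr - min_lr) / batch_num
    print(f"LR/0 = {lr:.3f}")  # 0.020, 0.040, 0.060, 0.080, 0.100
```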
74
 
75
  optimizer_class.next_batch = next_batch
76
  optimizer_class.next_epoch = next_epoch
77
+
78
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
79
  optimizer.max_lr = [0.1, 0, 0]
80
  return optimizer
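The per-group `momentum` and `weight_decay` values above use PyTorch's standard param-group mechanism; a self-contained sketch with a toy module in place of the YOLO model:

```python
import torch

model = torch.nn.Linear(4, 2)  # toy stand-in
optimizer = torch.optim.SGD(
    [
        {"params": [model.bias], "momentum": 0.8, "weight_decay": 0},
        {"params": [model.weight], "momentum": 0.8},
    ],
    lr=0.01,
    momentum=0.9,       # default, overridden per group above
    weight_decay=5e-4,  # applies only to groups that do not override it
)
for idx, group in enumerate(optimizer.param_groups):
    print(idx, group["momentum"], group["weight_decay"])  # momentum 0.8 in both; weight_decay 0 vs 5e-4
```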
 
172
  batch_json = []
173
  for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
174
  scale, shift = box_reverse.split([1, 4])
175
+ bboxes = bboxes.clone()
176
  bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
177
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
178
  for cls, *pos, conf in bboxes:
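The added `bboxes.clone()` keeps the JSON conversion from mutating the caller's prediction tensor in place; a minimal illustration of the difference:

```python
import torch

preds = torch.tensor([[0.0, 10.0, 10.0, 20.0, 20.0, 0.9]])
safe = preds.clone()
safe[:, 1:5] /= 2            # rescale a copy, as predicts_to_json now does
print(preds[0, 1].item())    # 10.0 — the original tensor is untouched
```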
yolo/utils/solver_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import contextlib
2
  import io
 
3
 
4
  import numpy as np
5
  from pycocotools.coco import COCO
@@ -17,7 +18,7 @@ def calculate_ap(coco_gt: COCO, pd_path):
17
  return coco_eval.stats
18
 
19
 
20
- def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
21
  ap_table = Table()
22
  ap_table.add_column("Epoch", justify="center", style="white", width=5)
23
  ap_table.add_column("Avg. Precision", justify="left", style="cyan")
@@ -30,7 +31,7 @@ def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
30
  if past_result:
31
  ap_table.add_row()
32
 
33
- color = np.where(last_score <= score, "[green]", "[red]")
34
 
35
  this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
36
  metrics = [
 
1
  import contextlib
2
  import io
3
+ from typing import Dict
4
 
5
  import numpy as np
6
  from pycocotools.coco import COCO
 
18
  return coco_eval.stats
19
 
20
 
21
+ def make_ap_table(score: Dict[str, float], past_result=[], max_result=None, epoch=-1):
22
  ap_table = Table()
23
  ap_table.add_column("Epoch", justify="center", style="white", width=5)
24
  ap_table.add_column("Avg. Precision", justify="left", style="cyan")
 
31
  if past_result:
32
  ap_table.add_row()
33
 
34
+ color = np.where(max_result <= score, "[green]", "[red]")
35
 
36
  this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
37
  metrics = [