henry000 commited on
Commit
b5fa3f1
·
1 Parent(s): 9276557

🧱 [Update] config files struct, make it clearly

Browse files
yolo/config/README.md DELETED
@@ -1 +0,0 @@
1
- model configuration
 
 
yolo/config/config.py CHANGED
@@ -1,5 +1,5 @@
1
  from dataclasses import dataclass
2
- from typing import Dict, List, Union
3
 
4
  from torch import nn
5
 
@@ -10,25 +10,47 @@ class AnchorConfig:
10
  strides: List[int]
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @dataclass
14
  class Model:
15
  anchor: AnchorConfig
16
- model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
17
 
18
 
19
  @dataclass
20
- class Download:
21
- auto: bool
 
 
 
 
 
 
 
 
 
 
22
  path: str
 
23
 
24
 
25
  @dataclass
26
- class DataLoaderConfig:
27
- batch_size: int
28
- class_num: int
29
- image_size: List[int]
30
  shuffle: bool
 
31
  pin_memory: bool
 
32
 
33
 
34
  @dataclass
@@ -44,16 +66,24 @@ class OptimizerConfig:
44
 
45
 
46
  @dataclass
47
- class SchedulerArgs:
48
- step_size: int
49
- gamma: float
 
 
 
 
 
 
 
 
50
 
51
 
52
  @dataclass
53
  class SchedulerConfig:
54
  type: str
55
- args: SchedulerArgs
56
- warmup: Dict[str, Union[str, int, float]]
57
 
58
 
59
  @dataclass
@@ -62,66 +92,48 @@ class EMAConfig:
62
  decay: float
63
 
64
 
65
- @dataclass
66
- class MatcherConfig:
67
- iou: str
68
- topk: int
69
- factor: Dict[str, int]
70
-
71
-
72
- @dataclass
73
- class LossConfig:
74
- objective: List[List]
75
- aux: Union[bool, float]
76
- matcher: MatcherConfig
77
-
78
-
79
  @dataclass
80
  class TrainConfig:
 
 
81
  epoch: int
 
82
  optimizer: OptimizerConfig
 
83
  scheduler: SchedulerConfig
84
  ema: EMAConfig
85
- loss: LossConfig
86
 
87
 
88
  @dataclass
89
- class GeneralConfig:
90
- out_path: str
91
- task: str
92
- device: Union[str, int, List[int]]
93
- cpu_num: int
94
- use_wandb: bool
95
- lucky_number: 10
96
- exist_ok: bool
97
- resume_train: bool
98
- use_TensorBoard: bool
99
 
100
 
101
  @dataclass
102
- class HyperConfig:
103
- general: GeneralConfig
104
- data: DataLoaderConfig
105
- train: TrainConfig
106
 
107
 
108
  @dataclass
109
- class Dataset:
110
- file_name: str
111
- num_files: int
 
112
 
 
 
113
 
114
- @dataclass
115
- class Datasets:
116
- base_url: str
117
- images: Dict[str, Dataset]
118
 
 
 
119
 
120
- @dataclass
121
- class Download:
122
- auto: bool
123
- save_path: str
124
- datasets: Datasets
125
 
126
 
127
  @dataclass
@@ -134,11 +146,3 @@ class YOLOLayer(nn.Module):
134
 
135
  def __post_init__(self):
136
  super().__init__()
137
-
138
-
139
- @dataclass
140
- class Config:
141
- model: Model
142
- download: Download
143
- hyper: HyperConfig
144
- name: str
 
1
  from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Union
3
 
4
  from torch import nn
5
 
 
10
  strides: List[int]
11
 
12
 
13
+ @dataclass
14
+ class LayerConfg:
15
+ args: Dict
16
+ source: Union[int, str, List[int]]
17
+ tags: str
18
+
19
+
20
+ @dataclass
21
+ class BlockConfig:
22
+ block: List[Dict[str, LayerConfg]]
23
+
24
+
25
  @dataclass
26
  class Model:
27
  anchor: AnchorConfig
28
+ model: Dict[str, BlockConfig]
29
 
30
 
31
  @dataclass
32
+ class DownloadDetail:
33
+ url: str
34
+ file_size: int
35
+
36
+
37
+ @dataclass
38
+ class DownloadOptions:
39
+ details: Dict[str, DownloadDetail]
40
+
41
+
42
+ @dataclass
43
+ class DatasetConfig:
44
  path: str
45
+ auto_download: Optional[DownloadOptions]
46
 
47
 
48
  @dataclass
49
+ class DataConfig:
 
 
 
50
  shuffle: bool
51
+ batch_size: int
52
  pin_memory: bool
53
+ data_augment: Dict[str, int]
54
 
55
 
56
  @dataclass
 
66
 
67
 
68
  @dataclass
69
+ class MatcherConfig:
70
+ iou: str
71
+ topk: int
72
+ factor: Dict[str, int]
73
+
74
+
75
+ @dataclass
76
+ class LossConfig:
77
+ objective: Dict[str, int]
78
+ aux: Union[bool, float]
79
+ matcher: MatcherConfig
80
 
81
 
82
  @dataclass
83
  class SchedulerConfig:
84
  type: str
85
+ warmup: Dict[str, Union[int, float]]
86
+ args: Dict[str, Any]
87
 
88
 
89
  @dataclass
 
92
  decay: float
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @dataclass
96
  class TrainConfig:
97
+ task: str
98
+ dataset: DatasetConfig
99
  epoch: int
100
+ data: DataConfig
101
  optimizer: OptimizerConfig
102
+ loss: LossConfig
103
  scheduler: SchedulerConfig
104
  ema: EMAConfig
 
105
 
106
 
107
  @dataclass
108
+ class NMSConfig:
109
+ min_confidence: int
110
+ min_iou: int
 
 
 
 
 
 
 
111
 
112
 
113
  @dataclass
114
+ class InferenceConfig:
115
+ task: str
116
+ nms: NMSConfig
 
117
 
118
 
119
  @dataclass
120
+ class Config:
121
+ task: Union[TrainConfig, InferenceConfig]
122
+ model: Model
123
+ name: str
124
 
125
+ device: Union[str, int, List[int]]
126
+ cpu_num: int
127
 
128
+ class_num: int
129
+ image_size: List[int]
 
 
130
 
131
+ out_path: str
132
+ exist_ok: bool
133
 
134
+ lucky_number: 10
135
+ use_wandb: bool
136
+ use_TensorBoard: bool
 
 
137
 
138
 
139
  @dataclass
 
146
 
147
  def __post_init__(self):
148
  super().__init__()
 
 
 
 
 
 
 
 
yolo/config/config.yaml CHANGED
@@ -1,13 +1,12 @@
1
  hydra:
2
  run:
3
- dir: ./runs
4
 
5
- defaults:
6
- - data: coco
7
- - download: ../data/download
8
- - augmentation: ../data/augmentation
9
- - model: v9-c
10
- - hyper: default
11
- - _self_
12
-
13
  name: v9-dev
 
 
 
 
 
 
 
 
1
  hydra:
2
  run:
3
+ dir: runs
4
 
 
 
 
 
 
 
 
 
5
  name: v9-dev
6
+
7
+ defaults:
8
+ - _self_
9
+ - task: train
10
+ - model: v9-c
11
+ - general
12
+
yolo/config/data/augmentation.yaml DELETED
@@ -1,3 +0,0 @@
1
- Mosaic: 1
2
- # MixUp: 1
3
- HorizontalFlip: 0.5
 
 
 
 
yolo/config/data/coco.yaml DELETED
@@ -1 +0,0 @@
1
- path: data/coco
 
 
yolo/config/general.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ deivce: [0]
2
+ cpu_num: 16
3
+
4
+ class_num: 80
5
+ image_size: [640, 640]
6
+
7
+ out_path: runs
8
+ exist_ok: True
9
+
10
+ lucky_number: 10
11
+ use_wandb: False
12
+ use_TensorBoard: False
yolo/config/hyper/default.yaml DELETED
@@ -1,48 +0,0 @@
1
- general:
2
- out_path: runs
3
- task: train
4
- deivce: [0]
5
- cpu_num: 16
6
- use_wandb: False
7
- lucky_number: 10
8
- exist_ok: True
9
- resume_train: False
10
- use_TensorBoard: False
11
- data:
12
- batch_size: 16
13
- class_num: 80
14
- image_size: [640, 640]
15
- shuffle: True
16
- pin_memory: True
17
- train:
18
- epoch: 500
19
- optimizer:
20
- type: SGD
21
- args:
22
- lr: 0.01
23
- weight_decay: 0.0005
24
- momentum: 0.937
25
- loss:
26
- objective:
27
- BCELoss: 0.5
28
- BoxLoss: 7.5
29
- DFLoss: 1.5
30
- aux:
31
- 0.25
32
- matcher:
33
- iou: CIoU
34
- topk: 10
35
- factor:
36
- iou: 6.0
37
- cls: 0.5
38
- scheduler:
39
- type: LinearLR
40
- warmup:
41
- epochs: 3.0
42
- args:
43
- total_iters: ${hyper.train.epoch}
44
- start_factor: 1
45
- end_factor: 0.01
46
- ema:
47
- enabled: true
48
- decay: 0.995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
yolo/config/{data/download.yaml → task/dataset/coco.yaml} RENAMED
@@ -1,6 +1,6 @@
1
- auto: True
2
- save_path: data/coco
3
- datasets:
4
  images:
5
  base_url: http://images.cocodataset.org/zips/
6
  train2017:
@@ -15,7 +15,4 @@ datasets:
15
  annotations:
16
  base_url: http://images.cocodataset.org/annotations/
17
  annotations:
18
- file_name: annotations_trainval2017
19
- hydra:
20
- run:
21
- dir: ./runs
 
1
+ path: data/coco
2
+
3
+ auto_download:
4
  images:
5
  base_url: http://images.cocodataset.org/zips/
6
  train2017:
 
15
  annotations:
16
  base_url: http://images.cocodataset.org/annotations/
17
  annotations:
18
+ file_name: annotations_trainval2017
 
 
 
yolo/config/task/train.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: train
2
+ defaults:
3
+ - dataset: coco
4
+
5
+ epoch: 500
6
+
7
+ data:
8
+ batch_size: 16
9
+ shuffle: True
10
+ pin_memory: True
11
+ data_augment:
12
+ Mosaic: 1
13
+ # MixUp: 1
14
+ HorizontalFlip: 0.5
15
+
16
+ optimizer:
17
+ type: SGD
18
+ args:
19
+ lr: 0.01
20
+ weight_decay: 0.0005
21
+ momentum: 0.937
22
+
23
+ loss:
24
+ objective:
25
+ BCELoss: 0.5
26
+ BoxLoss: 7.5
27
+ DFLoss: 1.5
28
+ aux:
29
+ 0.25
30
+ matcher:
31
+ iou: CIoU
32
+ topk: 10
33
+ factor:
34
+ iou: 6.0
35
+ cls: 0.5
36
+
37
+ scheduler:
38
+ type: LinearLR
39
+ warmup:
40
+ epochs: 3.0
41
+ args:
42
+ total_iters: ${task.epoch}
43
+ start_factor: 1
44
+ end_factor: 0.01
45
+
46
+ ema:
47
+ enabled: true
48
+ decay: 0.995
yolo/model/yolo.py CHANGED
@@ -123,7 +123,7 @@ def get_model(cfg: Config) -> YOLO:
123
  YOLO: An instance of the model defined by the given configuration.
124
  """
125
  OmegaConf.set_struct(cfg.model, False)
126
- model = YOLO(cfg.model, cfg.hyper.data.class_num)
127
  logger.info("✅ Success load model")
128
  log_model_structure(model.model)
129
  draw_model(model=model)
 
123
  YOLO: An instance of the model defined by the given configuration.
124
  """
125
  OmegaConf.set_struct(cfg.model, False)
126
+ model = YOLO(cfg.model, cfg.class_num)
127
  logger.info("✅ Success load model")
128
  log_model_structure(model.model)
129
  draw_model(model=model)
yolo/tools/data_loader.py CHANGED
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader, Dataset
12
  from torchvision.transforms import functional as TF
13
  from tqdm.rich import tqdm
14
 
15
- from yolo.config.config import Config
16
  from yolo.tools.data_augmentation import (
17
  AugmentationComposer,
18
  HorizontalFlip,
@@ -20,6 +20,7 @@ from yolo.tools.data_augmentation import (
20
  Mosaic,
21
  VerticalFlip,
22
  )
 
23
  from yolo.tools.drawer import draw_bboxes
24
  from yolo.utils.dataset_utils import (
25
  create_image_metadata,
@@ -29,16 +30,16 @@ from yolo.utils.dataset_utils import (
29
 
30
 
31
  class YoloDataset(Dataset):
32
- def __init__(self, config: dict, phase: str = "train2017", image_size: int = 640):
33
- dataset_cfg = config.data
34
- augment_cfg = config.augmentation
35
- phase_name = dataset_cfg.get(phase, phase)
36
  self.image_size = image_size
37
 
38
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
39
  self.transform = AugmentationComposer(transforms, self.image_size)
40
  self.transform.get_more_data = self.get_more_data
41
- self.data = self.load_data(dataset_cfg.path, phase_name)
42
 
43
  def load_data(self, dataset_path, phase_name):
44
  """
@@ -159,15 +160,15 @@ class YoloDataset(Dataset):
159
  class YoloDataLoader(DataLoader):
160
  def __init__(self, config: Config):
161
  """Initializes the YoloDataLoader with hydra-config files."""
162
- hyper = config.hyper.data
163
- dataset = YoloDataset(config)
164
 
165
  super().__init__(
166
  dataset,
167
- batch_size=hyper.batch_size,
168
- shuffle=hyper.shuffle,
169
- num_workers=config.hyper.general.cpu_num,
170
- pin_memory=hyper.pin_memory,
171
  collate_fn=self.collate_fn,
172
  )
173
 
@@ -197,7 +198,10 @@ class YoloDataLoader(DataLoader):
197
  return batch_images, batch_targets
198
 
199
 
200
- def create_dataloader(config):
 
 
 
201
  return YoloDataLoader(config)
202
 
203
 
@@ -211,7 +215,7 @@ if __name__ == "__main__":
211
  import sys
212
 
213
  sys.path.append("./")
214
- from tools.logging_utils import custom_logger
215
 
216
  custom_logger()
217
  main()
 
12
  from torchvision.transforms import functional as TF
13
  from tqdm.rich import tqdm
14
 
15
+ from yolo.config.config import Config, TrainConfig
16
  from yolo.tools.data_augmentation import (
17
  AugmentationComposer,
18
  HorizontalFlip,
 
20
  Mosaic,
21
  VerticalFlip,
22
  )
23
+ from yolo.tools.dataset_preparation import prepare_dataset
24
  from yolo.tools.drawer import draw_bboxes
25
  from yolo.utils.dataset_utils import (
26
  create_image_metadata,
 
30
 
31
 
32
  class YoloDataset(Dataset):
33
+ def __init__(self, config: TrainConfig, phase: str = "train2017", image_size: int = 640):
34
+ augment_cfg = config.data.data_augment
35
+ # TODO: add yaml -> train: train2017
36
+ phase_name = config.dataset.auto_download.get(phase, phase)
37
  self.image_size = image_size
38
 
39
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
40
  self.transform = AugmentationComposer(transforms, self.image_size)
41
  self.transform.get_more_data = self.get_more_data
42
+ self.data = self.load_data(config.dataset.path, phase_name)
43
 
44
  def load_data(self, dataset_path, phase_name):
45
  """
 
160
  class YoloDataLoader(DataLoader):
161
  def __init__(self, config: Config):
162
  """Initializes the YoloDataLoader with hydra-config files."""
163
+ data_cfg = config.task.data
164
+ dataset = YoloDataset(config.task)
165
 
166
  super().__init__(
167
  dataset,
168
+ batch_size=data_cfg.batch_size,
169
+ shuffle=data_cfg.shuffle,
170
+ num_workers=config.cpu_num,
171
+ pin_memory=data_cfg.pin_memory,
172
  collate_fn=self.collate_fn,
173
  )
174
 
 
198
  return batch_images, batch_targets
199
 
200
 
201
+ def create_dataloader(config: Config):
202
+ if config.task.dataset.auto_download:
203
+ prepare_dataset(config.task.dataset)
204
+
205
  return YoloDataLoader(config)
206
 
207
 
 
215
  import sys
216
 
217
  sys.path.append("./")
218
+ from utils.logging_utils import custom_logger
219
 
220
  custom_logger()
221
  main()
yolo/tools/dataset_preparation.py CHANGED
@@ -6,6 +6,8 @@ from hydra import main
6
  from loguru import logger
7
  from tqdm import tqdm
8
 
 
 
9
 
10
  def download_file(url, destination):
11
  """
@@ -45,12 +47,12 @@ def check_files(directory, expected_count=None):
45
 
46
 
47
  @main(config_path="../config/data", config_name="download", version_base=None)
48
- def prepare_dataset(cfg):
49
  """
50
  Prepares dataset by downloading and unzipping if necessary.
51
  """
52
- data_dir = cfg.save_path
53
- for data_type, settings in cfg.datasets.items():
54
  base_url = settings["base_url"]
55
  for dataset_type, dataset_args in settings.items():
56
  if dataset_type == "base_url":
 
6
  from loguru import logger
7
  from tqdm import tqdm
8
 
9
+ from yolo.config.config import DatasetConfig
10
+
11
 
12
  def download_file(url, destination):
13
  """
 
47
 
48
 
49
  @main(config_path="../config/data", config_name="download", version_base=None)
50
+ def prepare_dataset(cfg: DatasetConfig):
51
  """
52
  Prepares dataset by downloading and unzipping if necessary.
53
  """
54
+ data_dir = cfg.path
55
+ for data_type, settings in cfg.auto_download.items():
56
  base_url = settings["base_url"]
57
  for dataset_type, dataset_args in settings.items():
58
  if dataset_type == "base_url":
yolo/tools/format_converters.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
2
+ # TODO: need to refactor
3
+ for idx in range(model_size):
4
+ new_list, old_list = [], []
5
+ for weight_name, weight_value in new_state_dict.items():
6
+ if weight_name.split(".")[0] == str(idx):
7
+ new_list.append((weight_name, None))
8
+ for weight_name, weight_value in old_state_dict.items():
9
+ if f"model.{idx+1}." in weight_name:
10
+ old_list.append((weight_name, weight_value))
11
+ if len(new_list) == len(old_list):
12
+ for (weight_name, _), (_, weight_value) in zip(new_list, old_list):
13
+ new_state_dict[weight_name] = weight_value
14
+ else:
15
+ for weight_name, weight_value in old_list:
16
+ if "dfl" in weight_name:
17
+ continue
18
+ _, _, conv_name, conv_idx, *details = weight_name.split(".")
19
+ if conv_name == "cv4" or conv_name == "cv5":
20
+ conv_idx = str(int(conv_idx) + 3)
21
+
22
+ if conv_name == "cv2" or conv_name == "cv4":
23
+ conv_task = "anchor_conv"
24
+ if conv_name == "cv3" or conv_name == "cv5":
25
+ conv_task = "class_conv"
26
+
27
+ weight_name = ".".join(["37", "heads", conv_idx, conv_task, *details])
28
+ new_state_dict[weight_name] = weight_value
29
+ return new_state_dict
yolo/tools/loss_functions.py CHANGED
@@ -75,8 +75,8 @@ class DFLoss(nn.Module):
75
  class YOLOLoss:
76
  def __init__(self, cfg: Config) -> None:
77
  self.reg_max = cfg.model.anchor.reg_max
78
- self.class_num = cfg.hyper.data.class_num
79
- self.image_size = list(cfg.hyper.data.image_size)
80
  self.strides = cfg.model.anchor.strides
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
 
@@ -89,7 +89,7 @@ class YOLOLoss:
89
  self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
90
  self.iou = BoxLoss()
91
 
92
- self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
93
  self.box_converter = AnchorBoxConverter(cfg, device)
94
 
95
  def separate_anchor(self, anchors):
@@ -127,11 +127,11 @@ class YOLOLoss:
127
  class DualLoss:
128
  def __init__(self, cfg: Config) -> None:
129
  self.loss = YOLOLoss(cfg)
130
- self.aux_rate = cfg.hyper.train.loss.aux
131
 
132
- self.iou_rate = cfg.hyper.train.loss.objective["BoxLoss"]
133
- self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
134
- self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
135
 
136
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
137
  targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
 
75
  class YOLOLoss:
76
  def __init__(self, cfg: Config) -> None:
77
  self.reg_max = cfg.model.anchor.reg_max
78
+ self.class_num = cfg.class_num
79
+ self.image_size = list(cfg.image_size)
80
  self.strides = cfg.model.anchor.strides
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
 
 
89
  self.dfl = DFLoss(self.anchors, self.scaler, self.reg_max)
90
  self.iou = BoxLoss()
91
 
92
+ self.matcher = BoxMatcher(cfg.task.loss.matcher, self.class_num, self.anchors)
93
  self.box_converter = AnchorBoxConverter(cfg, device)
94
 
95
  def separate_anchor(self, anchors):
 
127
  class DualLoss:
128
  def __init__(self, cfg: Config) -> None:
129
  self.loss = YOLOLoss(cfg)
130
+ self.aux_rate = cfg.task.loss.aux
131
 
132
+ self.iou_rate = cfg.task.loss.objective["BoxLoss"]
133
+ self.dfl_rate = cfg.task.loss.objective["DFLoss"]
134
+ self.cls_rate = cfg.task.loss.objective["BCELoss"]
135
 
136
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
137
  targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
yolo/tools/trainer.py CHANGED
@@ -18,7 +18,7 @@ from yolo.utils.model_utils import (
18
 
19
  class ModelTrainer:
20
  def __init__(self, cfg: Config, save_path: str, device):
21
- train_cfg: TrainConfig = cfg.hyper.train
22
  model = get_model(cfg)
23
 
24
  self.model = model.to(device)
 
18
 
19
  class ModelTrainer:
20
  def __init__(self, cfg: Config, save_path: str, device):
21
+ train_cfg: TrainConfig = cfg.task
22
  model = get_model(cfg)
23
 
24
  self.model = model.to(device)
yolo/utils/bounding_box_utils.py CHANGED
@@ -7,7 +7,7 @@ from einops import rearrange
7
  from torch import Tensor
8
  from torchvision.ops import batched_nms
9
 
10
- from yolo.config.config import Config, MatcherConfig
11
 
12
 
13
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
@@ -127,8 +127,8 @@ def generate_anchors(image_size: List[int], strides: List[int], device):
127
  class AnchorBoxConverter:
128
  def __init__(self, cfg: Config, device: torch.device) -> None:
129
  self.reg_max = cfg.model.anchor.reg_max
130
- self.class_num = cfg.hyper.data.class_num
131
- self.image_size = list(cfg.hyper.data.image_size)
132
  self.strides = cfg.model.anchor.strides
133
 
134
  self.scale_up = torch.tensor(self.image_size * 2, device=device)
@@ -291,17 +291,17 @@ class BoxMatcher:
291
  return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
292
 
293
 
294
- def bbox_nms(predicts: Tensor, min_conf: float = 0, min_iou: float = 0.5):
295
  cls_dist, bbox = predicts.split([80, 4], dim=-1)
296
 
297
  # filter class by confidence
298
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
299
- valid_mask = cls_val > min_conf
300
  valid_cls = cls_idx[valid_mask]
301
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
302
 
303
  batch_idx, *_ = torch.where(valid_mask)
304
- nms_idx = batched_nms(valid_box, valid_cls, batch_idx, min_iou)
305
  predicts_nms = []
306
  for idx in range(batch_idx.max() + 1):
307
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
 
7
  from torch import Tensor
8
  from torchvision.ops import batched_nms
9
 
10
+ from yolo.config.config import Config, MatcherConfig, NMSConfig
11
 
12
 
13
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
 
127
  class AnchorBoxConverter:
128
  def __init__(self, cfg: Config, device: torch.device) -> None:
129
  self.reg_max = cfg.model.anchor.reg_max
130
+ self.class_num = cfg.class_num
131
+ self.image_size = list(cfg.image_size)
132
  self.strides = cfg.model.anchor.strides
133
 
134
  self.scale_up = torch.tensor(self.image_size * 2, device=device)
 
291
  return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
292
 
293
 
294
+ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
295
  cls_dist, bbox = predicts.split([80, 4], dim=-1)
296
 
297
  # filter class by confidence
298
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
299
+ valid_mask = cls_val > nms_cfg.min_confidence
300
  valid_cls = cls_idx[valid_mask]
301
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
302
 
303
  batch_idx, *_ = torch.where(valid_mask)
304
+ nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
305
  predicts_nms = []
306
  for idx in range(batch_idx.max() + 1):
307
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
yolo/utils/logging_utils.py CHANGED
@@ -24,7 +24,7 @@ from rich.table import Table
24
  from torch import Tensor
25
  from torch.optim import Optimizer
26
 
27
- from yolo.config.config import Config, GeneralConfig, YOLOLayer
28
 
29
 
30
  def custom_logger():
@@ -110,11 +110,11 @@ def log_model_structure(model: List[YOLOLayer]):
110
  console.print(table)
111
 
112
 
113
- def validate_log_directory(general_cfg: GeneralConfig, exp_name):
114
- base_path = os.path.join(general_cfg.out_path, general_cfg.task)
115
  save_path = os.path.join(base_path, exp_name)
116
 
117
- if not general_cfg.exist_ok:
118
  index = 1
119
  old_exp_name = exp_name
120
  while os.path.isdir(save_path):
 
24
  from torch import Tensor
25
  from torch.optim import Optimizer
26
 
27
+ from yolo.config.config import Config, YOLOLayer
28
 
29
 
30
  def custom_logger():
 
110
  console.print(table)
111
 
112
 
113
+ def validate_log_directory(cfg: Config, exp_name: str):
114
+ base_path = os.path.join(cfg.out_path, cfg.task.task)
115
  save_path = os.path.join(base_path, exp_name)
116
 
117
+ if not cfg.exist_ok:
118
  index = 1
119
  old_exp_name = exp_name
120
  while os.path.isdir(save_path):