henry000 committed
Commit e9816da · 2 parents: 89c6a27 d1aee89

🔀 [Merge] branch 'MODELv2' into DEPLOY

yolo/config/model/v9-m.yaml ADDED
@@ -0,0 +1,133 @@
+anchor:
+  reg_max: 16
+
+model:
+  backbone:
+    - Conv:
+        args: {out_channels: 32, kernel_size: 3, stride: 2}
+        source: 0
+    - Conv:
+        args: {out_channels: 64, kernel_size: 3, stride: 2}
+    - RepNCSPELAN:
+        args: {out_channels: 128, part_channels: 128}
+
+    - AConv:
+        args: {out_channels: 240}
+    - RepNCSPELAN:
+        args: {out_channels: 240, part_channels: 240}
+        tags: B3
+
+    - AConv:
+        args: {out_channels: 360}
+    - RepNCSPELAN:
+        args: {out_channels: 360, part_channels: 360}
+        tags: B4
+
+    - AConv:
+        args: {out_channels: 480}
+    - RepNCSPELAN:
+        args: {out_channels: 480, part_channels: 480}
+        tags: B5
+
+  neck:
+    - SPPELAN:
+        args: {out_channels: 480}
+        tags: N3
+
+    - UpSample:
+        args: {scale_factor: 2, mode: nearest}
+    - Concat:
+        source: [-1, B4]
+    - RepNCSPELAN:
+        args: {out_channels: 360, part_channels: 360}
+        tags: N4
+
+    - UpSample:
+        args: {scale_factor: 2, mode: nearest}
+    - Concat:
+        source: [-1, B3]
+
+  head:
+    - RepNCSPELAN:
+        args: {out_channels: 240, part_channels: 240}
+        tags: P3
+
+    - AConv:
+        args: {out_channels: 184}
+    - Concat:
+        source: [-1, N4]
+    - RepNCSPELAN:
+        args: {out_channels: 360, part_channels: 360}
+        tags: P4
+
+    - AConv:
+        args: {out_channels: 240}
+    - Concat:
+        source: [-1, N3]
+    - RepNCSPELAN:
+        args: {out_channels: 480, part_channels: 480}
+        tags: P5
+
+  detection:
+    - MultiheadDetection:
+        source: [P3, P4, P5]
+        tags: Main
+        args:
+          reg_max: ${model.anchor.reg_max}
+          output: True
+
+  auxiliary:
+    - CBLinear:
+        source: B3
+        args: {out_channels: [240]}
+        tags: R3
+    - CBLinear:
+        source: B4
+        args: {out_channels: [240, 360]}
+        tags: R4
+    - CBLinear:
+        source: B5
+        args: {out_channels: [240, 360, 480]}
+        tags: R5
+
+    - Conv:
+        args: {out_channels: 32, kernel_size: 3, stride: 2}
+        source: 0
+    - Conv:
+        args: {out_channels: 64, kernel_size: 3, stride: 2}
+    - RepNCSPELAN:
+        args: {out_channels: 128, part_channels: 128}
+
+    - AConv:
+        args: {out_channels: 240}
+    - CBFuse:
+        source: [R3, R4, R5, -1]
+        args: {index: [0, 0, 0]}
+    - RepNCSPELAN:
+        args: {out_channels: 240, part_channels: 240}
+        tags: A3
+
+    - AConv:
+        args: {out_channels: 360}
+    - CBFuse:
+        source: [R4, R5, -1]
+        args: {index: [1, 1]}
+    - RepNCSPELAN:
+        args: {out_channels: 360, part_channels: 360}
+        tags: A4
+
+    - AConv:
+        args: {out_channels: 480}
+    - CBFuse:
+        source: [R5, -1]
+        args: {index: [2]}
+    - RepNCSPELAN:
+        args: {out_channels: 480, part_channels: 480}
+        tags: A5
+
+    - MultiheadDetection:
+        source: [A3, A4, A5]
+        tags: AUX
+        args:
+          reg_max: ${model.anchor.reg_max}
+          output: True
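The `tags` entries (B3–B5, N3/N4, P3–P5, R3–R5, A3–A5) name intermediate outputs so later `source` fields can reference them, and `${model.anchor.reg_max}` is an OmegaConf interpolation that resolves once this file is composed under the top-level `model` group, as Hydra does in yolo/lazy.py. A minimal sketch of that resolution, assuming only that omegaconf is installed and the script runs from the repo root:

    # Minimal sketch: mount the file under `model`, as Hydra's config group does,
    # so ${model.anchor.reg_max} can resolve against the composed tree.
    from omegaconf import OmegaConf

    model_cfg = OmegaConf.load("yolo/config/model/v9-m.yaml")
    root = OmegaConf.create({"model": model_cfg})

    head = root.model.model.detection[0]["MultiheadDetection"]
    print(head.args.reg_max)  # 16, pulled from anchor.reg_max via interpolation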
yolo/config/model/v9-s.yaml ADDED
@@ -0,0 +1,134 @@
+anchor:
+  reg_max: 16
+
+model:
+  backbone:
+    - Conv:
+        args: {out_channels: 32, kernel_size: 3, stride: 2}
+        source: 0
+    - Conv:
+        args: {out_channels: 64, kernel_size: 3, stride: 2}
+    - ELAN:
+        args: {out_channels: 64, part_channels: 64}
+
+    - AConv:
+        args: {out_channels: 128}
+    - RepNCSPELAN:
+        args:
+          out_channels: 128
+          part_channels: 128
+          csp_args: {repeat_num: 3}
+        tags: B3 # 18
+
+    - AConv:
+        args: {out_channels: 192}
+    - RepNCSPELAN:
+        args:
+          out_channels: 192
+          part_channels: 192
+          csp_args: {repeat_num: 3}
+        tags: B4
+
+    - AConv:
+        args: {out_channels: 256}
+    - RepNCSPELAN:
+        args:
+          out_channels: 256
+          part_channels: 256
+          csp_args: {repeat_num: 3}
+        tags: B5
+
+  neck:
+    - SPPELAN:
+        args: {out_channels: 256}
+        tags: N3
+
+    - UpSample:
+        args: {scale_factor: 2, mode: nearest}
+    - Concat:
+        source: [-1, B4]
+    - RepNCSPELAN:
+        args:
+          out_channels: 192
+          part_channels: 192
+          csp_args: {repeat_num: 3}
+        tags: N4
+
+    - UpSample:
+        args: {scale_factor: 2, mode: nearest}
+    - Concat:
+        source: [-1, B3]
+
+    - RepNCSPELAN:
+        args:
+          out_channels: 128
+          part_channels: 128
+          csp_args: {repeat_num: 3}
+        tags: P3
+    - AConv:
+        args: {out_channels: 96}
+    - Concat:
+        source: [-1, N4]
+
+    - RepNCSPELAN:
+        args:
+          out_channels: 192
+          part_channels: 192
+          csp_args: {repeat_num: 3}
+        tags: P4
+    - AConv:
+        args: {out_channels: 128}
+    - Concat:
+        source: [-1, N3]
+
+    - RepNCSPELAN:
+        args:
+          out_channels: 256
+          part_channels: 256
+          csp_args: {repeat_num: 3}
+        tags: P5
+
+  detection:
+    - MultiheadDetection:
+        source: [P3, P4, P5]
+        tags: Main
+        args:
+          reg_max: ${model.anchor.reg_max}
+          output: True
+
+  head:
+    - SPPELAN:
+        source: B5
+        args: {out_channels: 256}
+        tags: A5
+
+    - UpSample:
+        args: {scale_factor: 2, mode: nearest}
+    - Concat:
+        source: [-1, B4]
+
+    - RepNCSPELAN:
+        args:
+          out_channels: 192
+          part_channels: 192
+          csp_args: {repeat_num: 3}
+        tags: A4
+
+    - UpSample:
+        args: {scale_factor: 2, mode: nearest}
+    - Concat:
+        source: [-1, B3]
+
+    - RepNCSPELAN:
+        args:
+          out_channels: 128
+          part_channels: 128
+          csp_args: {repeat_num: 3}
+        tags: A3
+
+    - MultiheadDetection:
+        source: [A3, A4, A5]
+        tags: AUX
+        args:
+          reg_max: ${model.anchor.reg_max}
+          output: True
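Relative to v9-m, the small variant swaps the first RepNCSPELAN for a plain ELAN, trims channel widths, and caps each RepNCSP stack at three repeats via `csp_args: {repeat_num: 3}`. One thing the model builder must get right with these configs is channel bookkeeping across `Concat`: in the P4 row above, AConv emits 96 channels and the N4 tap contributes 192, so the following RepNCSPELAN sees 288 input channels. A quick hedged check in plain PyTorch, with no repo imports:

    # Hedged arithmetic check: the Concat feeding P4 joins the 96-channel
    # AConv output with the 192-channel map tagged N4.
    import torch

    aconv_out = torch.randn(1, 96, 40, 40)   # AConv: {out_channels: 96}
    n4 = torch.randn(1, 192, 40, 40)         # RepNCSPELAN tagged N4
    p4_in = torch.cat([aconv_out, n4], dim=1)
    print(p4_in.shape[1])  # 288 -> in_channels inferred for the P4 RepNCSPELAN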
yolo/lazy.py CHANGED
@@ -13,17 +13,16 @@ from yolo.tools.data_loader import create_dataloader
 from yolo.tools.solver import ModelTester, ModelTrainer
 from yolo.utils.bounding_box_utils import Vec2Box
 from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import custom_logger, validate_log_directory
+from yolo.utils.logging_utils import ProgressLogger


 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-    custom_logger()
-    save_path = validate_log_directory(cfg, exp_name=cfg.name)
+    progress = ProgressLogger(cfg, exp_name=cfg.name)
     dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
     device = torch.device(cfg.device)
     if getattr(cfg.task, "fast_inference", False):
-        model = FastModelLoader(cfg).load_model()
+        model = FastModelLoader(cfg, device).load_model()
         device = torch.device(cfg.device)
     else:
         model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
@@ -31,11 +30,11 @@ def main(cfg: Config):
     vec2box = Vec2Box(model, cfg.image_size, device)

     if cfg.task.task == "train":
-        trainer = ModelTrainer(cfg, model, vec2box, save_path, device)
+        trainer = ModelTrainer(cfg, model, vec2box, progress, device)
         trainer.solve(dataloader)

     if cfg.task.task == "inference":
-        tester = ModelTester(cfg, model, vec2box, save_path, device)
+        tester = ModelTester(cfg, model, vec2box, progress, device)
         tester.solve(dataloader)

yolo/model/module.py CHANGED
@@ -192,6 +192,36 @@ class RepNCSP(nn.Module):
         return self.conv3(torch.cat((x1, x2), dim=1))


+class ELAN(nn.Module):
+    """ELAN structure."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        part_channels: int,
+        *,
+        process_channels: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if process_channels is None:
+            process_channels = part_channels // 2
+
+        self.conv1 = Conv(in_channels, part_channels, 1, **kwargs)
+        self.conv2 = Conv(part_channels // 2, process_channels, 3, padding=1, **kwargs)
+        self.conv3 = Conv(process_channels, process_channels, 3, padding=1, **kwargs)
+        self.conv4 = Conv(part_channels + 2 * process_channels, out_channels, 1, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1, x2 = self.conv1(x).chunk(2, 1)
+        x3 = self.conv2(x2)
+        x4 = self.conv3(x3)
+        x5 = self.conv4(torch.cat([x1, x2, x3, x4], dim=1))
+        return x5
+
+
 class RepNCSPELAN(nn.Module):
     """RepNCSPELAN block combining RepNCSP blocks with ELAN structure."""

@@ -230,6 +260,21 @@ class RepNCSPELAN(nn.Module):
         return x5


+class AConv(nn.Module):
+    """Downsampling module combining average pooling with a strided convolution for feature reduction."""
+
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        mid_layer = {"kernel_size": 3, "stride": 2}
+        self.avg_pool = Pool("avg", kernel_size=2, stride=1)
+        self.conv = Conv(in_channels, out_channels, **mid_layer)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.avg_pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ADown(nn.Module):
     """Downsampling module combining average and max pooling with convolution for feature reduction."""

@@ -498,26 +543,6 @@ class CSPDark(nn.Module):
         return self.cv2(torch.cat((self.cb(y[0]), y[1]), 1))


-# ELAN
-class ELAN(nn.Module):
-    # ELAN
-    def __init__(self, in_channels, out_channels, med_channels, elan_repeat=2, cb_repeat=2, ratio=1.0):
-
-        super().__init__()
-
-        h_channels = med_channels // 2
-        self.cv1 = Conv(in_channels, med_channels, 1, 1)
-        self.cb = nn.ModuleList(ConvBlock(h_channels, repeat=cb_repeat, ratio=ratio) for _ in range(elan_repeat))
-        self.cv2 = Conv((2 + elan_repeat) * h_channels, out_channels, 1, 1)
-
-    def forward(self, x):
-
-        y = list(self.cv1(x).chunk(2, 1))
-        y.extend((m(y[-1])) for m in self.cb)
-
-        return self.cv2(torch.cat(y, 1))
-
-
 class CSPELAN(nn.Module):
     # ELAN
     def __init__(self, in_channels, out_channels, med_channels, elan_repeat=2, cb_repeat=2, ratio=1.0):
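The new ELAN keeps spatial resolution (1×1 and padded 3×3 convolutions only) and concatenates part_channels + 2 * process_channels features before projecting to out_channels, while AConv halves the resolution with its stride-2 convolution after a light average pool. A hedged smoke test, assuming the repo's Conv/Pool wrappers apply 'same'-style auto padding as the rest of the module suggests:

    # Hedged smoke test for the two modules added above.
    import torch
    from yolo.model.module import AConv, ELAN

    x = torch.randn(1, 64, 80, 80)

    elan = ELAN(in_channels=64, out_channels=64, part_channels=64)
    print(elan(x).shape)  # expected torch.Size([1, 64, 80, 80]): spatial size kept

    down = AConv(in_channels=64, out_channels=128)
    print(down(x).shape)  # expected torch.Size([1, 128, 40, 40]): stride 2 halves H and W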
yolo/model/yolo.py CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Union
+from typing import Dict, List, Optional, Union

 import torch
 from loguru import logger
@@ -43,7 +43,7 @@ class YOLO(nn.Module):
         source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

         # Find in channels
-        if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
+        if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
             layer_args["in_channels"] = output_dim[source]
         if "Detection" in layer_type:
             layer_args["in_channels"] = [output_dim[idx] for idx in source]
@@ -81,7 +81,7 @@ class YOLO(nn.Module):
         return output

     def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
-        if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
+        if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
             return layer_args["out_channels"]
         if layer_type == "CBFuse":
             return output_dim[source[-1]]
@@ -117,9 +117,7 @@ class YOLO(nn.Module):
         raise ValueError(f"Unsupported layer type: {layer_type}")


-def create_model(
-    model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
-) -> YOLO:
+def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: device, class_num: int = 80) -> YOLO:
     """Constructs and returns a model from a Dictionary configuration file.

     Args:
@@ -135,8 +133,9 @@ def create_model(
     if not os.path.exists(weight_path):
         logger.info(f"🌐 Weight {weight_path} not found, try downloading")
         prepare_weight(weight_path=weight_path)
-    model.model.load_state_dict(torch.load(weight_path, map_location=device))
-    logger.info("✅ Success load model weight")
+    if os.path.exists(weight_path):
+        model.model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
+        logger.info("✅ Success load model weight")

     log_model_structure(model.model)
     draw_model(model=model)
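The switch to strict=False is what lets checkpoints trained before this merge load into graphs that now contain ELAN/AConv layers: mismatched keys are reported instead of raising. A generic PyTorch illustration, not repo code:

    # strict=False returns the mismatches instead of raising, so a partially
    # matching checkpoint still initializes the layers it does cover.
    import torch.nn as nn

    old = nn.Sequential(nn.Conv2d(3, 16, 3))
    new = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 32, 3))

    result = new.load_state_dict(old.state_dict(), strict=False)
    print(result.missing_keys)  # ['1.weight', '1.bias'] -- tolerated, not fatal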
yolo/tools/data_loader.py CHANGED
@@ -23,7 +23,6 @@ from yolo.tools.data_augmentation import (
     VerticalFlip,
 )
 from yolo.tools.dataset_preparation import prepare_dataset
-from yolo.tools.drawer import draw_bboxes
 from yolo.utils.dataset_utils import (
     create_image_metadata,
     locate_label_paths,
@@ -204,7 +203,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str):
         return StreamDataLoader(data_cfg)

     if dataset_cfg.auto_download:
-        prepare_dataset(dataset_cfg)
+        prepare_dataset(dataset_cfg, task)

     return YoloDataLoader(data_cfg, dataset_cfg, task)
yolo/tools/dataset_preparation.py CHANGED
@@ -52,7 +52,7 @@ def check_files(directory, expected_count=None):
     return len(files) == expected_count if expected_count is not None else bool(files)


-def prepare_dataset(dataset_cfg: DatasetConfig):
+def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
     """
     Prepares dataset by downloading and unzipping if necessary.
     """
@@ -60,8 +60,8 @@ def prepare_dataset(dataset_cfg: DatasetConfig):
     for data_type, settings in dataset_cfg.auto_download.items():
         base_url = settings["base_url"]
         for dataset_type, dataset_args in settings.items():
-            if dataset_type == "base_url":
-                continue  # Skip the base_url entry
+            if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
+                continue
             file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
             url = f"{base_url}{file_name}"
             local_zip_path = os.path.join(data_dir, file_name)
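With the task threaded in, only the archive matching the active split (plus "annotations") is downloaded, rather than every entry under auto_download; dataset_cfg.get(task, task) lets a config remap a task name to a differently named split. A hedged standalone rendering of the filter, where the key names are hypothetical stand-ins:

    # Hedged rendering of the new skip condition; the split keys are hypothetical.
    def should_download(dataset_type: str, task: str, split_alias: dict) -> bool:
        # Mirrors: dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type
        return dataset_type == "annotations" or split_alias.get(task, task) == dataset_type

    for key in ["base_url", "train2017", "val2017", "annotations"]:
        print(key, should_download(key, "train", {"train": "train2017"}))
    # base_url False, train2017 True, val2017 False, annotations True
    # (base_url now fails the filter implicitly, so the old explicit skip could go.)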
yolo/tools/solver.py CHANGED
@@ -14,7 +14,7 @@ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import create_loss_function
 from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
-from yolo.utils.logging_utils import ProgressTracker
+from yolo.utils.logging_utils import ProgressLogger
 from yolo.utils.model_utils import (
     ExponentialMovingAverage,
     create_optimizer,
@@ -23,7 +23,7 @@ from yolo.utils.model_utils import (


 class ModelTrainer:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
+    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
         train_cfg: TrainConfig = cfg.task
         self.model = model
         self.vec2box = vec2box
@@ -31,11 +31,11 @@ class ModelTrainer:
         self.optimizer = create_optimizer(model, train_cfg.optimizer)
         self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
         self.loss_fn = create_loss_function(cfg, vec2box)
-        self.progress = ProgressTracker(cfg.name, save_path, cfg.use_wandb)
+        self.progress = progress
         self.num_epochs = cfg.task.epoch

         self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
-        self.validator = ModelValidator(cfg.task.validation, model, vec2box, save_path, device, self.progress)
+        self.validator = ModelValidator(cfg.task.validation, model, vec2box, device, self.progress)

         if getattr(train_cfg.ema, "enabled", False):
             self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -102,14 +102,15 @@ class ModelTrainer:


 class ModelTester:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, save_path: str, device):
+    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
         self.model = model
         self.device = device
         self.vec2box = vec2box
-        self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
+        self.progress = progress

         self.nms = cfg.task.nms
-        self.save_path = save_path
+        self.save_path = os.path.join(progress.save_path, "images")
+        os.makedirs(self.save_path, exist_ok=True)
         self.save_predict = getattr(cfg.task, "save_predict", None)
         self.idx2label = cfg.class_list

@@ -142,8 +143,7 @@ class ModelTester:
                 break
             if not self.save_predict:
                 continue
-
-            if self.save_predict == False:
+            if self.save_predict != False:
                 save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
                 img.save(save_image_path)
                 logger.info(f"💾 Saved visualize image at {save_image_path}")
@@ -164,16 +164,13 @@ class ModelValidator:
         validation_cfg: ValidationConfig,
         model: YOLO,
         vec2box: Vec2Box,
-        save_path: str,
         device,
-        # TODO: think Progress?
-        progress: ProgressTracker,
+        progress: ProgressLogger,
     ):
         self.model = model
         self.vec2box = vec2box
         self.device = device
         self.progress = progress
-        self.save_path = save_path

         self.nms = validation_cfg.nms
yolo/utils/logging_utils.py CHANGED
@@ -38,15 +38,18 @@ def custom_logger(quite: bool = False):
     )


-class ProgressTracker:
-    def __init__(self, exp_name: str, save_path: str, use_wandb: bool = False):
+class ProgressLogger:
+    def __init__(self, cfg: Config, exp_name: str):
+        custom_logger(getattr(cfg, "quite", False))
+        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
+
         self.progress = Progress(
             TextColumn("[progress.description]{task.description}"),
             BarColumn(bar_width=None),
             TextColumn("{task.completed:.0f}/{task.total:.0f}"),
             TimeRemainingColumn(),
         )
-        self.use_wandb = use_wandb
+        self.use_wandb = cfg.use_wandb
         if self.use_wandb:
             wandb.errors.term._log = custom_wandb_log
             self.wandb = wandb.init(
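ProgressLogger now owns the whole logging bootstrap: it configures loguru via custom_logger, validates the run directory, and reads use_wandb off the config, so callers hold a single object. A hedged usage sketch, assuming a composed Hydra `cfg` as in yolo/lazy.py; the printed path layout is whatever validate_log_directory produces and is shown here only illustratively:

    # Hedged usage sketch; `cfg` is the composed Hydra config, as in yolo/lazy.py.
    from yolo.utils.logging_utils import ProgressLogger

    progress = ProgressLogger(cfg, exp_name=cfg.name)  # sets up loguru + run dir
    print(progress.save_path)  # illustrative, e.g. runs/<task>/<exp_name>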