🔀 [Merge] branch 'MODELv2' into DEPLOY

Changed files:
- yolo/config/model/v9-m.yaml +133 -0
- yolo/config/model/v9-s.yaml +134 -0
- yolo/lazy.py +5 -6
- yolo/model/module.py +45 -20
- yolo/model/yolo.py +7 -8
- yolo/tools/data_loader.py +1 -2
- yolo/tools/dataset_preparation.py +3 -3
- yolo/tools/solver.py +10 -13
- yolo/utils/logging_utils.py +6 -3
yolo/config/model/v9-m.yaml
ADDED
@@ -0,0 +1,133 @@
anchor:
  reg_max: 16

model:
  backbone:
    - Conv:
        args: {out_channels: 32, kernel_size: 3, stride: 2}
        source: 0
    - Conv:
        args: {out_channels: 64, kernel_size: 3, stride: 2}
    - RepNCSPELAN:
        args: {out_channels: 128, part_channels: 128}

    - AConv:
        args: {out_channels: 240}
    - RepNCSPELAN:
        args: {out_channels: 240, part_channels: 240}
        tags: B3

    - AConv:
        args: {out_channels: 360}
    - RepNCSPELAN:
        args: {out_channels: 360, part_channels: 360}
        tags: B4

    - AConv:
        args: {out_channels: 480}
    - RepNCSPELAN:
        args: {out_channels: 480, part_channels: 480}
        tags: B5

  neck:
    - SPPELAN:
        args: {out_channels: 480}
        tags: N3

    - UpSample:
        args: {scale_factor: 2, mode: nearest}
    - Concat:
        source: [-1, B4]
    - RepNCSPELAN:
        args: {out_channels: 360, part_channels: 360}
        tags: N4

    - UpSample:
        args: {scale_factor: 2, mode: nearest}
    - Concat:
        source: [-1, B3]

  head:
    - RepNCSPELAN:
        args: {out_channels: 240, part_channels: 240}
        tags: P3

    - AConv:
        args: {out_channels: 184}
    - Concat:
        source: [-1, N4]
    - RepNCSPELAN:
        args: {out_channels: 360, part_channels: 360}
        tags: P4

    - AConv:
        args: {out_channels: 240}
    - Concat:
        source: [-1, N3]
    - RepNCSPELAN:
        args: {out_channels: 480, part_channels: 480}
        tags: P5

  detection:
    - MultiheadDetection:
        source: [P3, P4, P5]
        tags: Main
        args:
          reg_max: ${model.anchor.reg_max}
        output: True

  auxiliary:
    - CBLinear:
        source: B3
        args: {out_channels: [240]}
        tags: R3
    - CBLinear:
        source: B4
        args: {out_channels: [240, 360]}
        tags: R4
    - CBLinear:
        source: B5
        args: {out_channels: [240, 360, 480]}
        tags: R5

    - Conv:
        args: {out_channels: 32, kernel_size: 3, stride: 2}
        source: 0
    - Conv:
        args: {out_channels: 64, kernel_size: 3, stride: 2}
    - RepNCSPELAN:
        args: {out_channels: 128, part_channels: 128}

    - AConv:
        args: {out_channels: 240}
    - CBFuse:
        source: [R3, R4, R5, -1]
        args: {index: [0, 0, 0]}
    - RepNCSPELAN:
        args: {out_channels: 240, part_channels: 240}
        tags: A3

    - AConv:
        args: {out_channels: 360}
    - CBFuse:
        source: [R4, R5, -1]
        args: {index: [1, 1]}
    - RepNCSPELAN:
        args: {out_channels: 360, part_channels: 360}
        tags: A4

    - AConv:
        args: {out_channels: 480}
    - CBFuse:
        source: [R5, -1]
        args: {index: [2]}
    - RepNCSPELAN:
        args: {out_channels: 480, part_channels: 480}
        tags: A5

    - MultiheadDetection:
        source: [A3, A4, A5]
        tags: AUX
        args:
          reg_max: ${model.anchor.reg_max}
        output: True
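The ${model.anchor.reg_max} interpolation only resolves once this file sits under a top-level "model" key in the composed config, which is presumably how the Hydra "model" config group mounts it. A minimal sketch of that assumption with plain OmegaConf:

    # Sketch only: assumes the YAML above is mounted under a "model" key,
    # the way a Hydra config group named "model" would mount it.
    from omegaconf import OmegaConf

    root = OmegaConf.create({"model": OmegaConf.load("yolo/config/model/v9-m.yaml")})
    OmegaConf.resolve(root)

    # Main and AUX heads now share the anchor setting:
    head = root.model.model.detection[0]["MultiheadDetection"]
    assert head.args.reg_max == 16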
yolo/config/model/v9-s.yaml
ADDED
@@ -0,0 +1,134 @@
anchor:
  reg_max: 16

model:
  backbone:
    - Conv:
        args: {out_channels: 32, kernel_size: 3, stride: 2}
        source: 0
    - Conv:
        args: {out_channels: 64, kernel_size: 3, stride: 2}
    - ELAN:
        args: {out_channels: 64, part_channels: 64}

    - AConv:
        args: {out_channels: 128}
    - RepNCSPELAN:
        args:
          out_channels: 128
          part_channels: 128
          csp_args: {repeat_num: 3}
        tags: B3 # 18

    - AConv:
        args: {out_channels: 192}
    - RepNCSPELAN:
        args:
          out_channels: 192
          part_channels: 192
          csp_args: {repeat_num: 3}
        tags: B4

    - AConv:
        args: {out_channels: 256}
    - RepNCSPELAN:
        args:
          out_channels: 256
          part_channels: 256
          csp_args: {repeat_num: 3}
        tags: B5

  neck:
    - SPPELAN:
        args: {out_channels: 256}
        tags: N3

    - UpSample:
        args: {scale_factor: 2, mode: nearest}
    - Concat:
        source: [-1, B4]
    - RepNCSPELAN:
        args:
          out_channels: 192
          part_channels: 192
          csp_args: {repeat_num: 3}
        tags: N4

    - UpSample:
        args: {scale_factor: 2, mode: nearest}
    - Concat:
        source: [-1, B3]

    - RepNCSPELAN:
        args:
          out_channels: 128
          part_channels: 128
          csp_args: {repeat_num: 3}
        tags: P3
    - AConv:
        args: {out_channels: 96}
    - Concat:
        source: [-1, N4]

    - RepNCSPELAN:
        args:
          out_channels: 192
          part_channels: 192
          csp_args: {repeat_num: 3}
        tags: P4
    - AConv:
        args: {out_channels: 128}
    - Concat:
        source: [-1, N3]

    - RepNCSPELAN:
        args:
          out_channels: 256
          part_channels: 256
          csp_args: {repeat_num: 3}
        tags: P5

  detection:
    - MultiheadDetection:
        source: [P3, P4, P5]
        tags: Main
        args:
          reg_max: ${model.anchor.reg_max}
        output: True

  head:
    - SPPELAN:
        source: B5
        args: {out_channels: 256}
        tags: A5

    - UpSample:
        args: {scale_factor: 2, mode: nearest}
    - Concat:
        source: [-1, B4]

    - RepNCSPELAN:
        args:
          out_channels: 192
          part_channels: 192
          csp_args: {repeat_num: 3}
        tags: A4

    - UpSample:
        args: {scale_factor: 2, mode: nearest}
    - Concat:
        source: [-1, B3]

    - RepNCSPELAN:
        args:
          out_channels: 128
          part_channels: 128
          csp_args: {repeat_num: 3}
        tags: A3

    - MultiheadDetection:
        source: [A3, A4, A5]
        tags: AUX
        args:
          reg_max: ${model.anchor.reg_max}
        output: True
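Compared with v9-m, the small variant swaps the first RepNCSPELAN for the new lighter ELAN block, trims channel widths (256 instead of 480 at the deepest stage), and caps every RepNCSP at csp_args: {repeat_num: 3}. Assuming both files register as options of a Hydra "model" config group, selecting the small variant would presumably look like (hypothetical invocation, not shown in this diff):

    python yolo/lazy.py task=train model=v9-s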
yolo/lazy.py
CHANGED
@@ -13,17 +13,16 @@ from yolo.tools.data_loader import create_dataloader
 from yolo.tools.solver import ModelTester, ModelTrainer
 from yolo.utils.bounding_box_utils import Vec2Box
 from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import
+from yolo.utils.logging_utils import ProgressLogger


 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-
-    save_path = validate_log_directory(cfg, exp_name=cfg.name)
+    progress = ProgressLogger(cfg, exp_name=cfg.name)
     dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
     device = torch.device(cfg.device)
     if getattr(cfg.task, "fast_inference", False):
-        model = FastModelLoader(cfg).load_model()
+        model = FastModelLoader(cfg, device).load_model()
         device = torch.device(cfg.device)
     else:
         model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
@@ -31,11 +30,11 @@ def main(cfg: Config):
     vec2box = Vec2Box(model, cfg.image_size, device)

     if cfg.task.task == "train":
-        trainer = ModelTrainer(cfg, model, vec2box,
+        trainer = ModelTrainer(cfg, model, vec2box, progress, device)
         trainer.solve(dataloader)

     if cfg.task.task == "inference":
-        tester = ModelTester(cfg, model, vec2box,
+        tester = ModelTester(cfg, model, vec2box, progress, device)
         tester.solve(dataloader)
yolo/model/module.py
CHANGED
@@ -192,6 +192,36 @@ class RepNCSP(nn.Module):
         return self.conv3(torch.cat((x1, x2), dim=1))


+class ELAN(nn.Module):
+    """ELAN structure."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        part_channels: int,
+        *,
+        process_channels: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if process_channels is None:
+            process_channels = part_channels // 2
+
+        self.conv1 = Conv(in_channels, part_channels, 1, **kwargs)
+        self.conv2 = Conv(part_channels // 2, process_channels, 3, padding=1, **kwargs)
+        self.conv3 = Conv(process_channels, process_channels, 3, padding=1, **kwargs)
+        self.conv4 = Conv(part_channels + 2 * process_channels, out_channels, 1, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1, x2 = self.conv1(x).chunk(2, 1)
+        x3 = self.conv2(x2)
+        x4 = self.conv3(x3)
+        x5 = self.conv4(torch.cat([x1, x2, x3, x4], dim=1))
+        return x5
+
+
 class RepNCSPELAN(nn.Module):
     """RepNCSPELAN block combining RepNCSP blocks with ELAN structure."""

@@ -230,6 +260,21 @@ class RepNCSPELAN(nn.Module):
         return x5


+class AConv(nn.Module):
+    """Downsampling module combining average and max pooling with convolution for feature reduction."""
+
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        mid_layer = {"kernel_size": 3, "stride": 2}
+        self.avg_pool = Pool("avg", kernel_size=2, stride=1)
+        self.conv = Conv(in_channels, out_channels, **mid_layer)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.avg_pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ADown(nn.Module):
     """Downsampling module combining average and max pooling with convolution for feature reduction."""

@@ -498,26 +543,6 @@ class CSPDark(nn.Module):
         return self.cv2(torch.cat((self.cb(y[0]), y[1]), 1))


-# ELAN
-class ELAN(nn.Module):
-    # ELAN
-    def __init__(self, in_channels, out_channels, med_channels, elan_repeat=2, cb_repeat=2, ratio=1.0):
-
-        super().__init__()
-
-        h_channels = med_channels // 2
-        self.cv1 = Conv(in_channels, med_channels, 1, 1)
-        self.cb = nn.ModuleList(ConvBlock(h_channels, repeat=cb_repeat, ratio=ratio) for _ in range(elan_repeat))
-        self.cv2 = Conv((2 + elan_repeat) * h_channels, out_channels, 1, 1)
-
-    def forward(self, x):
-
-        y = list(self.cv1(x).chunk(2, 1))
-        y.extend((m(y[-1])) for m in self.cb)
-
-        return self.cv2(torch.cat(y, 1))
-
-
 class CSPELAN(nn.Module):
     # ELAN
     def __init__(self, in_channels, out_channels, med_channels, elan_repeat=2, cb_repeat=2, ratio=1.0):
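A quick shape check of the two new blocks, as a sketch: it assumes Conv applies its usual same-padding for the 3x3 stride-2 case and that Pool wraps the matching torch pooling layer, as elsewhere in yolo/model/module.py.

    import torch
    from yolo.model.module import AConv, ELAN

    x = torch.randn(1, 64, 80, 80)

    # ELAN keeps the spatial size; the cat of [x1, x2, x3, x4] feeds
    # part_channels + 2 * process_channels channels into conv4.
    elan = ELAN(in_channels=64, out_channels=64, part_channels=64)
    print(elan(x).shape)   # expected: torch.Size([1, 64, 80, 80])

    # AConv halves the spatial size: 2x2 avg pool (stride 1), then a 3x3 stride-2 Conv.
    down = AConv(in_channels=64, out_channels=128)
    print(down(x).shape)   # expected: torch.Size([1, 128, 40, 40])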
yolo/model/yolo.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import
+from typing import Dict, List, Optional, Union

 import torch
 from loguru import logger
@@ -43,7 +43,7 @@ class YOLO(nn.Module):
             source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

             # Find in channels
-            if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
+            if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                 layer_args["in_channels"] = output_dim[source]
             if "Detection" in layer_type:
                 layer_args["in_channels"] = [output_dim[idx] for idx in source]
@@ -81,7 +81,7 @@ class YOLO(nn.Module):
         return output

     def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
-        if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
+        if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
             return layer_args["out_channels"]
         if layer_type == "CBFuse":
             return output_dim[source[-1]]
@@ -117,9 +117,7 @@ class YOLO(nn.Module):
         raise ValueError(f"Unsupported layer type: {layer_type}")


-def create_model(
-    model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
-) -> YOLO:
+def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: device, class_num: int = 80) -> YOLO:
     """Constructs and returns a model from a Dictionary configuration file.

     Args:
@@ -135,8 +133,9 @@ def create_model(
     if not os.path.exists(weight_path):
         logger.info(f"🌐 Weight {weight_path} not found, try downloading")
         prepare_weight(weight_path=weight_path)
-
-
+    if os.path.exists(weight_path):
+        model.model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
+        logger.info("✅ Success load model weight")

     log_model_structure(model.model)
     draw_model(model=model)
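Note that weight_path and device are now required positional parameters (the weights/v9-c.pt and device("cuda") defaults are gone), so every call site has to supply them explicitly, as yolo/lazy.py does above:

    device = torch.device(cfg.device)
    model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)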
yolo/tools/data_loader.py
CHANGED
@@ -23,7 +23,6 @@ from yolo.tools.data_augmentation import (
     VerticalFlip,
 )
 from yolo.tools.dataset_preparation import prepare_dataset
-from yolo.tools.drawer import draw_bboxes
 from yolo.utils.dataset_utils import (
     create_image_metadata,
     locate_label_paths,
@@ -204,7 +203,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str
         return StreamDataLoader(data_cfg)

     if dataset_cfg.auto_download:
-        prepare_dataset(dataset_cfg)
+        prepare_dataset(dataset_cfg, task)

     return YoloDataLoader(data_cfg, dataset_cfg, task)
yolo/tools/dataset_preparation.py
CHANGED
@@ -52,7 +52,7 @@ def check_files(directory, expected_count=None):
     return len(files) == expected_count if expected_count is not None else bool(files)


-def prepare_dataset(dataset_cfg: DatasetConfig):
+def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
     """
     Prepares dataset by downloading and unzipping if necessary.
     """
@@ -60,8 +60,8 @@ def prepare_dataset(dataset_cfg: DatasetConfig):
     for data_type, settings in dataset_cfg.auto_download.items():
         base_url = settings["base_url"]
         for dataset_type, dataset_args in settings.items():
-            if dataset_type
-                continue
+            if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
+                continue
             file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
             url = f"{base_url}{file_name}"
             local_zip_path = os.path.join(data_dir, file_name)
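With the task threaded through, the guard skips every archive the current task does not need: annotations always download, while an image split downloads only when dataset_cfg maps the task to it (falling back to the task name itself). A toy run of the same condition, over a purely hypothetical auto_download layout:

    # Hypothetical settings dict, only to illustrate the filter above:
    settings = {"base_url": "http://images.cocodataset.org/zips/", "train2017": {}, "val2017": {}}
    dataset_cfg = {"train": "train2017", "val": "val2017"}  # stand-in for DatasetConfig
    task = "train"

    for dataset_type in settings:
        if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
            continue  # skips base_url and val2017 when task == "train"
        print(dataset_type)  # -> train2017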
yolo/tools/solver.py
CHANGED
@@ -14,7 +14,7 @@ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import create_loss_function
 from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
-from yolo.utils.logging_utils import
+from yolo.utils.logging_utils import ProgressLogger
 from yolo.utils.model_utils import (
     ExponentialMovingAverage,
     create_optimizer,
@@ -23,7 +23,7 @@ from yolo.utils.model_utils import (


 class ModelTrainer:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box,
+    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
         train_cfg: TrainConfig = cfg.task
         self.model = model
         self.vec2box = vec2box
@@ -31,11 +31,11 @@ class ModelTrainer:
         self.optimizer = create_optimizer(model, train_cfg.optimizer)
         self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
         self.loss_fn = create_loss_function(cfg, vec2box)
-        self.progress =
+        self.progress = progress
         self.num_epochs = cfg.task.epoch

         self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
-        self.validator = ModelValidator(cfg.task.validation, model, vec2box,
+        self.validator = ModelValidator(cfg.task.validation, model, vec2box, device, self.progress)

         if getattr(train_cfg.ema, "enabled", False):
             self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -102,14 +102,15 @@ class ModelTrainer:


 class ModelTester:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box,
+    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
         self.model = model
         self.device = device
         self.vec2box = vec2box
-        self.progress =
+        self.progress = progress

         self.nms = cfg.task.nms
-        self.save_path = save_path
+        self.save_path = os.path.join(progress.save_path, "images")
+        os.makedirs(self.save_path, exist_ok=True)
         self.save_predict = getattr(cfg.task, "save_predict", None)
         self.idx2label = cfg.class_list

@@ -142,8 +143,7 @@ class ModelTester:
                 break
             if not self.save_predict:
                 continue
-
-            if self.save_predict == False:
+            if self.save_predict != False:
                 save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
                 img.save(save_image_path)
                 logger.info(f"💾 Saved visualize image at {save_image_path}")
@@ -164,16 +164,13 @@ class ModelValidator:
         validation_cfg: ValidationConfig,
         model: YOLO,
         vec2box: Vec2Box,
-        save_path: str,
         device,
-
-        progress: ProgressTracker,
+        progress: ProgressLogger,
     ):
         self.model = model
         self.vec2box = vec2box
         self.device = device
         self.progress = progress
-        self.save_path = save_path

         self.nms = validation_cfg.nms
yolo/utils/logging_utils.py
CHANGED
@@ -38,15 +38,18 @@ def custom_logger(quite: bool = False):
     )


-class
-    def __init__(self,
+class ProgressLogger:
+    def __init__(self, cfg: Config, exp_name: str):
+        custom_logger(getattr(cfg, "quite", False))
+        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
+
         self.progress = Progress(
             TextColumn("[progress.description]{task.description}"),
             BarColumn(bar_width=None),
             TextColumn("{task.completed:.0f}/{task.total:.0f}"),
             TimeRemainingColumn(),
         )
-        self.use_wandb = use_wandb
+        self.use_wandb = cfg.use_wandb
         if self.use_wandb:
             wandb.errors.term._log = custom_wandb_log
             self.wandb = wandb.init(