[Update] some bugs and variable names in Vec2Box
- yolo/__init__.py +2 -1
- yolo/config/model/v9-c.yaml +8 -8
- yolo/tools/format_converters.py +4 -2
- yolo/tools/solver.py +21 -16
- yolo/utils/bounding_box_utils.py +2 -0
yolo/__init__.py
CHANGED
@@ -3,7 +3,7 @@ from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import bbox_nms
+from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
 from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import custom_logger

@@ -13,6 +13,7 @@ __all__ = [
     "custom_logger",
     "validate_log_directory",
     "draw_bboxes",
+    "Vec2Box",
     "bbox_nms",
     "AugmentationComposer",
     "create_dataloader",
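The net effect is that Vec2Box is now re-exported at the package root next to bbox_nms. A minimal import sketch (illustrative only; it assumes nothing beyond the package being installed as `yolo`):

    # Vec2Box and bbox_nms both live in yolo.utils.bounding_box_utils,
    # but after this change they can be imported from the package root.
    from yolo import Vec2Box, bbox_nms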
yolo/config/model/v9-c.yaml
CHANGED
@@ -68,6 +68,14 @@ model:
         args: {out_channels: 512, part_channels: 512}
         tags: P5

+  detection:
+    - MultiheadDetection:
+        source: [P3, P4, P5]
+        tags: Main
+        args:
+          reg_max: ${model.anchor.reg_max}
+          output: True
+
   auxiliary:
     - CBLinear:
         source: B3
@@ -123,11 +131,3 @@ model:
         args:
           reg_max: ${model.anchor.reg_max}
           output: True
-
-  detection:
-    - MultiheadDetection:
-        source: [P3, P4, P5]
-        tags: Main
-        args:
-          reg_max: ${model.anchor.reg_max}
-          output: True
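Functionally, the Main MultiheadDetection block moves from the end of the file to just before the auxiliary branch, so the main head is now declared first. A quick sanity check, sketched in Python and hedged: the `${model.anchor.reg_max}` interpolation implies an OmegaConf/Hydra-style loader, the path is relative to the repo root, and a top-level `model:` mapping is assumed, as the hunk headers above suggest.

    from omegaconf import OmegaConf

    # Load the raw model config and confirm the Main detection head is now
    # declared before the auxiliary branch.
    cfg = OmegaConf.load("yolo/config/model/v9-c.yaml")
    keys = list(cfg.model.keys())
    assert keys.index("detection") < keys.index("auxiliary")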
yolo/tools/format_converters.py
CHANGED
@@ -1,12 +1,13 @@
 def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
     # TODO: need to refactor
+    shift = 1
     for idx in range(model_size):
         new_list, old_list = [], []
         for weight_name, weight_value in new_state_dict.items():
             if weight_name.split(".")[0] == str(idx):
                 new_list.append((weight_name, None))
         for weight_name, weight_value in old_state_dict.items():
-            if f"model.{idx+1}." in weight_name:
+            if f"model.{idx+shift}." in weight_name:
                 old_list.append((weight_name, weight_value))
         if len(new_list) == len(old_list):
             for (weight_name, _), (_, weight_value) in zip(new_list, old_list):
@@ -17,7 +18,8 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
             continue
         _, _, conv_name, conv_idx, *details = weight_name.split(".")
         if conv_name == "cv4" or conv_name == "cv5":
-            layer_idx =
+            layer_idx = 22
+            shift = 2
         else:
             layer_idx = 37

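Because the Main detection head now sits earlier in the layer list, the index offset between the old and new checkpoints is no longer constant: it starts at 1 and becomes 2 once the converter reaches the cv4/cv5 detection weights, which is what the new `shift` variable tracks. A hedged usage sketch follows; the checkpoint filename is a placeholder, `model` is assumed to be an already-built v9-c network from this repo, and it assumes `convert_weight` fills `new_state_dict` in place, as the zip loop above suggests.

    import torch
    from yolo.tools.format_converters import convert_weight

    old_state_dict = torch.load("v9-c-original.pt", map_location="cpu")  # placeholder path
    new_state_dict = model.state_dict()  # `model`: a freshly built v9-c network
    convert_weight(old_state_dict, new_state_dict, model_size=38)
    model.load_state_dict(new_state_dict)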
yolo/tools/solver.py
CHANGED
@@ -32,7 +32,7 @@ class ModelTrainer:
         self.num_epochs = cfg.task.epoch

         self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
-        self.validator = ModelValidator(cfg.task.validation, model, save_path, device, self.progress)
+        self.validator = ModelValidator(cfg.task.validation, model, vec2box, save_path, device, self.progress)

         if getattr(train_cfg.ema, "enabled", False):
             self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -40,14 +40,14 @@ class ModelTrainer:
             self.ema = None
         self.scaler = GradScaler()

-    def train_one_batch(self,
-
+    def train_one_batch(self, images: Tensor, targets: Tensor):
+        images, targets = images.to(self.device), targets.to(self.device)
         self.optimizer.zero_grad()

         with autocast():
-
-            aux_predicts = self.vec2box(
-            main_predicts = self.vec2box(
+            predicts = self.model(images)
+            aux_predicts = self.vec2box(predicts["AUX"])
+            main_predicts = self.vec2box(predicts["Main"])
             loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)

         self.scaler.scale(loss).backward()
@@ -60,8 +60,8 @@ class ModelTrainer:
         self.model.train()
         total_loss = 0

-        for
-            loss, loss_each = self.train_one_batch(
+        for images, targets in dataloader:
+            loss, loss_each = self.train_one_batch(images, targets)

             total_loss += loss
             self.progress.one_batch(loss_each)
@@ -111,14 +111,15 @@ class ModelTester:

     def solve(self, dataloader: StreamDataLoader):
         logger.info("π Start Inference!")
-
+        if isinstance(self.model, torch.nn.Module):
+            self.model.eval()
         try:
             for idx, images in enumerate(dataloader):
                 images = images.to(self.device)
                 with torch.no_grad():
-
-
-                    nms_out = bbox_nms(
+                    predicts = self.model(images)
+                    predicts = self.vec2box(predicts["Main"])
+                    nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
                 draw_bboxes(
                     images[0],
                     nms_out[0],
@@ -141,15 +142,18 @@ class ModelValidator:
         self,
         validation_cfg: ValidationConfig,
         model: YOLO,
+        vec2box: Vec2Box,
         save_path: str,
         device,
         # TODO: think Progress?
         progress: ProgressTracker,
     ):
         self.model = model
+        self.vec2box = vec2box
         self.device = device
         self.progress = progress
         self.save_path = save_path
+
         self.nms = validation_cfg.nms

     def solve(self, dataloader):
@@ -159,11 +163,12 @@ class ModelValidator:
         iou_thresholds = torch.arange(0.5, 1.0, 0.05)
         map_all = []
         self.progress.start_one_epoch(len(dataloader))
-        for
-
+        for images, targets in dataloader:
+            images, targets = images.to(self.device), targets.to(self.device)
             with torch.no_grad():
-
-
+                predicts = self.model(images)
+                predicts = self.vec2box(predicts["Main"])
+                nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
             for idx, predict in enumerate(nms_out):
                 map_value = calculate_map(predict, targets[idx], iou_thresholds)
                 map_all.append(map_value[0])
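Across ModelTrainer, ModelTester, and ModelValidator the shared pattern is now: run the model, decode the tagged raw outputs with Vec2Box, then post-process. A condensed sketch of that flow, hedged: `model`, `vec2box`, `images`, and `nms_cfg` are assumed to already exist, and the reading of `predicts[0]`/`predicts[2]` as class scores and boxes is inferred from how bbox_nms is called above.

    import torch
    from yolo.utils.bounding_box_utils import bbox_nms

    model.eval()
    with torch.no_grad():
        raw = model(images)              # dict keyed by head tags, e.g. "Main" and "AUX"
        predicts = vec2box(raw["Main"])  # decode raw vectors into per-anchor tensors
        nms_out = bbox_nms(predicts[0], predicts[2], nms_cfg)  # boxes surviving NMS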
yolo/utils/bounding_box_utils.py
CHANGED
@@ -4,6 +4,7 @@ from typing import List, Tuple
 import torch
 import torch.nn.functional as F
 from einops import rearrange
+from loguru import logger
 from torch import Tensor
 from torchvision.ops import batched_nms

@@ -264,6 +265,7 @@ class BoxMatcher:

 class Vec2Box:
     def __init__(self, model, image_size, device):
+        logger.info("🧸 Make a dummy test for auto-anchor size")
         dummy_input = torch.zeros(1, 3, *image_size).to(device)
         dummy_output = model(dummy_input)
         anchors_num = []
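The new log line makes it explicit that constructing Vec2Box runs one dummy forward pass to discover the anchor grid sizes. A construction sketch, hedged: the 640x640 image size and the device choice are placeholders, and `model` is assumed to be an already-built YOLO network from this repo.

    import torch
    from yolo.utils.bounding_box_utils import Vec2Box

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vec2box = Vec2Box(model, image_size=(640, 640), device=device)  # triggers the dummy forward pass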