✨ [New] Use the Lightning framework for training!
Changed files:

- yolo/__init__.py (+10 -7)
- yolo/lazy.py (+21 -26)
- yolo/tools/solver.py (+77 -255)
- yolo/utils/bounding_box_utils.py (+7 -0)
- yolo/utils/logging_utils.py (+207 -197)
- yolo/utils/model_utils.py (+2 -16)
- yolo/utils/solver_utils.py (+3 -2)
yolo/__init__.py  CHANGED

```diff
@@ -2,18 +2,22 @@ from yolo.config.config import Config, NMSConfig
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
+from yolo.tools.solver import TrainModel
 from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
 from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import …
+from yolo.utils.logging_utils import (
+    ImageLogger,
+    YOLORichModelSummary,
+    YOLORichProgressBar,
+)
 from yolo.utils.model_utils import PostProccess

 __all__ = [
     "create_model",
     "Config",
-    "…",
+    "YOLORichProgressBar",
     "NMSConfig",
-    "…",
+    "YOLORichModelSummary",
     "validate_log_directory",
     "draw_bboxes",
     "Vec2Box",
@@ -21,10 +25,9 @@ __all__ = [
     "bbox_nms",
     "create_converter",
     "AugmentationComposer",
+    "ImageLogger",
     "create_dataloader",
     "FastModelLoader",
-    "ModelTester",
-    "ModelTrainer",
-    "ModelValidator",
+    "TrainModel",
     "PostProccess",
 ]
```
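The package now re-exports the Lightning entry points instead of the old solver trio. A minimal sketch of how the new surface is meant to be used (the `cfg` object and its `task.epoch` field are assumptions based on the Hydra config referenced in `yolo/lazy.py`):

```python
# Sketch: consuming the new public API. Assumes a Hydra-style cfg object
# with the fields used in yolo/lazy.py.
from lightning import Trainer

from yolo import TrainModel  # re-exported from yolo.tools.solver


def run_training(cfg):
    model = TrainModel(cfg)                      # LightningModule wrapping create_model(...)
    trainer = Trainer(max_epochs=cfg.task.epoch)
    trainer.fit(model)                           # replaces the old ModelTrainer.solve(dataloader)
```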
yolo/lazy.py  CHANGED

```diff
@@ -2,41 +2,36 @@ import sys
 from pathlib import Path

 import hydra
+from lightning import Trainer

 project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))

 from yolo.config.config import Config
-from yolo.…
-from yolo.…
-from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import create_converter
-from yolo.utils.deploy_utils import FastModelLoader
-from yolo.utils.logging_utils import ProgressLogger
-from yolo.utils.model_utils import get_device
+from yolo.tools.solver import TrainModel, ValidateModel
+from yolo.utils.logging_utils import setup


 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-    …
-    solver.solve(dataloader)
+    callbacks, loggers = setup(cfg)
+
+    trainer = Trainer(
+        accelerator="cuda",
+        max_epochs=getattr(cfg.task, "epoch", None),
+        precision="16-mixed",
+        callbacks=callbacks,
+        logger=loggers,
+        log_every_n_steps=1,
+    )
+
+    match cfg.task.task:
+        case "train":
+            model = TrainModel(cfg)
+            trainer.fit(model)
+        case "validation":
+            model = ValidateModel(cfg)
+            trainer.validate(model)


 if __name__ == "__main__":
```
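For reference, a sketch of driving the new entry point programmatically rather than through the CLI. `hydra.compose` and `hydra.initialize` are standard Hydra APIs, but the config path and the `task=train` override are assumptions about this repo's config layout:

```python
# Sketch: compose the Hydra config by hand and run the same Trainer flow
# as yolo/lazy.py. Config path and override names are assumptions.
from hydra import compose, initialize
from lightning import Trainer

from yolo.tools.solver import TrainModel

with initialize(config_path="yolo/config", version_base=None):
    cfg = compose(config_name="config", overrides=["task=train"])

trainer = Trainer(accelerator="cuda", max_epochs=cfg.task.epoch, precision="16-mixed")
trainer.fit(TrainModel(cfg))
```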
yolo/tools/solver.py  CHANGED

```diff
@@ -1,267 +1,89 @@
-import contextlib
-import io
-import json
-import os
-import time
-from collections import defaultdict
-from pathlib import Path
-from typing import Dict, Optional
-
-import torch
-from pycocotools.coco import COCO
-from torch import Tensor, distributed
-from torch.cuda.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.utils.data import DataLoader
-
-from yolo.config.config import Config, DatasetConfig, TrainConfig, ValidationConfig
-from yolo.model.yolo import YOLO
-from yolo.tools.data_loader import StreamDataLoader, create_dataloader
-from yolo.tools.drawer import draw_bboxes, draw_model
-from yolo.tools.loss_functions import create_loss_function
-from yolo.utils.bounding_box_utils import Vec2Box, calculate_map
-from yolo.utils.dataset_utils import locate_label_paths
-from yolo.utils.logger import logger
-from yolo.utils.logging_utils import ProgressLogger, log_model_structure
-from yolo.utils.model_utils import (
-    ExponentialMovingAverage,
-    PostProccess,
-    collect_prediction,
-    create_optimizer,
-    create_scheduler,
-    predicts_to_json,
-)
-from yolo.utils.solver_utils import calculate_ap
-
-
-class ModelTrainer:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
-        train_cfg: TrainConfig = cfg.task
-        self.model = model if not use_ddp else DDP(model, device_ids=[device])
-        self.use_ddp = use_ddp
-        self.vec2box = vec2box
-        self.device = device
-        self.optimizer = create_optimizer(model, train_cfg.optimizer)
-        self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
-        self.loss_fn = create_loss_function(cfg, vec2box)
-        self.progress = progress
-        self.num_epochs = cfg.task.epoch
-        self.mAPs_dict = defaultdict(list)
-
-        …
-            draw_model(model=model)
-
-        self.validation_dataloader = create_dataloader(
-            cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
-        )
-        self.validator = ModelValidator(cfg.task.validation, cfg.dataset, model, vec2box, progress, device)
-
-        if …:
-            …
-        else:
-            self.ema = None
-        self.scaler = GradScaler()
-
-    def train_one_batch(self, images: Tensor, targets: Tensor):
-        images, targets = images.to(self.device), targets.to(self.device)
-        self.optimizer.zero_grad()
-
-        with autocast():
-            predicts = self.model(images)
-            aux_predicts = self.vec2box(predicts["AUX"])
-            main_predicts = self.vec2box(predicts["Main"])
-            loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
-
-        self.scaler.scale(loss).backward()
-        self.scaler.unscale_(self.optimizer)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-
-        return loss_item
-
-    def train_one_epoch(self, dataloader):
-        self.model.train()
-        total_loss = defaultdict(float)
-        total_samples = 0
-        self.optimizer.next_epoch(len(dataloader))
-        for batch_size, images, targets, *_ in dataloader:
-            self.optimizer.next_batch()
-            loss_each = self.train_one_batch(images, targets)
-
-            for loss_name, loss_val in loss_each.items():
-                if self.use_ddp:  # collecting loss for each batch
-                    distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG)
-                total_loss[loss_name] += loss_val.item() * batch_size
-            total_samples += batch_size
-            self.progress.one_batch(loss_each)
-
-        for loss_val in total_loss.values():
-            loss_val /= total_samples
-
-        if self.scheduler:
-            self.scheduler.step()
-
-        return total_loss
-
-    def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
-        file_name = file_name or f"E{epoch_idx:03d}.pt"
-        file_path = self.weights_dir / file_name
-
-        checkpoint = {
-            "epoch": epoch_idx,
-            "model_state_dict": self.model.state_dict(),
-            "optimizer_state_dict": self.optimizer.state_dict(),
-        }
-        if self.ema:
-            self.ema.apply_shadow()
-            checkpoint["model_state_dict_ema"] = self.model.state_dict()
-            self.ema.restore()
-
-        logger.info(f"💾 success save at {file_path}")
-        torch.save(checkpoint, file_path)
-
-    def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
-        save_flag = True
-        for mAP_key, mAP_val in mAPs.items():
-            self.mAPs_dict[mAP_key].append(mAP_val)
-            if mAP_val < max(self.mAPs_dict[mAP_key]):
-                save_flag = False
-        return save_flag
-
-    def solve(self, dataloader: DataLoader):
-        logger.info("🚄 Start Training!")
-        num_epochs = self.num_epochs
-
-        self.progress.start_train(num_epochs)
-        for epoch_idx in range(num_epochs):
-            if self.use_ddp:
-                dataloader.sampler.set_epoch(epoch_idx)
-
-            …
-            if mAPs is not None and self.good_epoch(mAPs):
-                self.save_checkpoint(epoch_idx=epoch_idx)
-            # TODO: save model if result are better than before
-        self.progress.finish_train()
-
-
-class ModelTester:
-    def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
-        self.model = model
-        self.device = device
-        self.progress = progress
-
-        self.post_proccess = PostProccess(vec2box, cfg.task.nms)
-        self.save_path = progress.save_path / "images"
-        os.makedirs(self.save_path, exist_ok=True)
-        self.save_predict = getattr(cfg.task, "save_predict", None)
-        self.idx2label = cfg.dataset.class_list
-
-    def solve(self, dataloader: StreamDataLoader):
-        logger.info("👀 Start Inference!")
-        if isinstance(self.model, torch.nn.Module):
-            self.model.eval()
-
-        if dataloader.is_stream:
-            import cv2
-            import numpy as np
-
-            last_time = time.time()
-        try:
-            for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
-                images = images.to(self.device)
-                rev_tensor = rev_tensor.to(self.device)
-                with torch.no_grad():
-                    predicts = self.model(images)
-                    predicts = self.post_proccess(predicts, rev_tensor)
-                img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
-
-                if dataloader.is_stream:
-                    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-                    fps = 1 / (time.time() - last_time)
-                    cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
-                    last_time = time.time()
-                    cv2.imshow("Prediction", img)
-                    if cv2.waitKey(1) & 0xFF == ord("q"):
-                        break
-                if not self.save_predict:
-                    continue
-                if self.save_predict != False:
-                    save_image_path = self.save_path / f"frame{idx:03d}.png"
-                    img.save(save_image_path)
-                    logger.info(f"💾 Saved visualize image at {save_image_path}")
-
-        except (KeyboardInterrupt, Exception) as e:
-            dataloader.stop_event.set()
-            dataloader.stop()
-            if isinstance(e, KeyboardInterrupt):
-                logger.error("User Keyboard Interrupt")
-            else:
-                raise e
-        dataloader.stop()
-
-
-class ModelValidator:
-    def __init__(
-        self,
-        validation_cfg: ValidationConfig,
-        dataset_cfg: DatasetConfig,
-        model: YOLO,
-        vec2box: Vec2Box,
-        progress: ProgressLogger,
-        device,
-    ):
-        self.model = model
-        self.device = device
-        self.progress = progress
-
-        self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
-        self.json_path = self.progress.save_path / "predict.json"
-
-        with contextlib.redirect_stdout(io.StringIO()):
-            # TODO: load with config file
-            json_path, _ = locate_label_paths(Path(dataset_cfg.path), dataset_cfg.get("validation", "val"))
-            if json_path:
-                self.coco_gt = COCO(json_path)
-
-    def solve(self, dataloader, epoch_idx=1):
-        # logger.info("🧪 Start Validation!")
-        self.model.eval()
-        predict_json, mAPs = [], defaultdict(list)
-        self.progress.start_one_epoch(len(dataloader), task="Validate")
-        for batch_size, images, targets, rev_tensor, img_paths in dataloader:
-            images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
-            with torch.no_grad():
-                predicts = self.model(images)
-                predicts = self.post_proccess(predicts)
-                for idx, predict in enumerate(predicts):
-                    mAP = calculate_map(predict, targets[idx])
-                    for mAP_key, mAP_val in mAP.items():
-                        mAPs[mAP_key].append(mAP_val)
-
-            avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
-            self.progress.one_batch(avg_mAPs)
-
-            …
-        self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
-
-        …
-        self.progress.start_pycocotools()
-        result = calculate_ap(self.coco_gt, predict_json)
-        self.progress.finish_pycocotools(result, epoch_idx)
-
-        …
+from lightning import LightningModule
+from torchmetrics.detection import MeanAveragePrecision
+
+from yolo.config.config import Config
+from yolo.model.yolo import create_model
+from yolo.tools.data_loader import create_dataloader
+from yolo.tools.loss_functions import create_loss_function
+from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
+from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
+
+
+class BaseModel(LightningModule):
+    def __init__(self, cfg: Config):
+        super().__init__()
+        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class ValidateModel(BaseModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+        if self.cfg.task.task == "validation":
+            self.validation_cfg = self.cfg.task
+        else:
+            self.validation_cfg = self.cfg.task.validation
+        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
+
+    def setup(self, stage):
+        self.vec2box = create_converter(
+            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
+        )
+        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
+
+    def val_dataloader(self):
+        return create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
+
+    def validation_step(self, batch, batch_idx):
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        predicts = self.post_proccess(self(images))
+        batch_metrics = self.metric(
+            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
+        )
+
+        self.log_dict(
+            {
+                "map": batch_metrics["map"],
+                "map_50": batch_metrics["map_50"],
+            },
+            on_step=True,
+            prog_bar=True,
+            logger=False,
+            batch_size=batch_size,
+        )
+        return predicts
+
+    def on_validation_epoch_end(self):
+        epoch_metrics = self.metric.compute()
+        del epoch_metrics["classes"]
+        self.log_dict(epoch_metrics, on_epoch=True, prog_bar=True, logger=True)
+
+
+class TrainModel(ValidateModel):
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    def setup(self, stage):
+        super().setup(stage)
+        self.loss_fn = create_loss_function(self.cfg, self.vec2box)
+
+    def train_dataloader(self):
+        return create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
+
+    def training_step(self, batch, batch_idx):
+        batch_size, images, targets, *_ = batch
+        predicts = self(images)
+        aux_predicts = self.vec2box(predicts["AUX"])
+        main_predicts = self.vec2box(predicts["Main"])
+        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
+        self.log_dict(loss_item, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
+        return loss * batch_size
+
+    def configure_optimizers(self):
+        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
+        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
+        return [optimizer], [scheduler]
```
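The rewrite moves each piece of the old imperative loop into a Lightning hook: `train_one_batch` becomes `training_step`, optimizer and scheduler construction moves to `configure_optimizers`, and the manual `GradScaler`/`autocast`, DDP wrapping, and checkpointing are delegated to the `Trainer`. A self-contained toy sketch of that pattern (not the YOLO classes themselves):

```python
# Sketch of the migration pattern: what ModelTrainer did by hand is split
# across LightningModule hooks, while Trainer(precision="16-mixed") would
# replace the manual GradScaler/autocast and DDP wrapping. Toy model only.
import torch
from lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.log("loss", loss, prog_bar=True)  # replaces progress.one_batch(...)
        return loss  # backward/step/zero_grad are handled by Lightning

    def configure_optimizers(self):
        # replaces the create_optimizer/create_scheduler plumbing in __init__
        return torch.optim.SGD(self.parameters(), lr=1e-2)


dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
trainer = Trainer(max_epochs=1, accelerator="cpu", logger=False, enable_checkpointing=False)
trainer.fit(ToyModule(), DataLoader(dataset, batch_size=8))
```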
yolo/utils/bounding_box_utils.py  CHANGED

```diff
@@ -446,3 +446,10 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
         "mAP.5:.95": torch.mean(torch.stack(aps)),
     }
     return mAP
+
+
+def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
+    bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
+    if prediction.size(1) == 6:
+        bbox["scores"] = prediction[:, 5]
+    return bbox
```
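`to_metrics_format` adapts the repo's row layout, `[class, x1, y1, x2, y2, conf]` for predictions and `[class, x1, y1, x2, y2]` for targets, into the dict format that `torchmetrics.detection.MeanAveragePrecision` consumes. A toy end-to-end check (box values invented for illustration):

```python
# Sketch: feeding to_metrics_format's output into torchmetrics, mirroring
# ValidateModel.validation_step. Rows follow the column layout above.
import torch
from torchmetrics.detection import MeanAveragePrecision


def to_metrics_format(prediction: torch.Tensor) -> dict:
    bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
    if prediction.size(1) == 6:  # trailing column is the confidence score
        bbox["scores"] = prediction[:, 5]
    return bbox


pred = torch.tensor([[0, 10.0, 10.0, 50.0, 50.0, 0.9]])  # one confident detection
target = torch.tensor([[0, 12.0, 11.0, 49.0, 52.0]])     # matching ground truth

metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
metric.update([to_metrics_format(pred)], [to_metrics_format(target)])
print(metric.compute()["map_50"])  # ~1.0 for this near-perfect match
```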
yolo/utils/logging_utils.py  CHANGED

```diff
@@ -11,9 +11,7 @@ Example:
     custom_logger()
 """

-import …
-import random
-import sys
+import logging
 from collections import deque
 from logging import FileHandler
 from pathlib import Path
@@ -22,39 +20,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import wandb
-import …
+from lightning import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
+from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
+from lightning.pytorch.loggers import WandbLogger
 from omegaconf import ListConfig
+from rich import reconfigure
 from rich.console import Console, Group
-from rich.progress import (
-    BarColumn,
-    Progress,
-    SpinnerColumn,
-    TextColumn,
-    TimeRemainingColumn,
-)
+from rich.logging import RichHandler
 from rich.table import Table
+from rich.text import Text
 from torch import Tensor
 from torch.nn import ModuleList
-from …
-from torchvision.transforms.functional import pil_to_tensor
+from typing_extensions import override

 from yolo.config.config import Config, YOLOLayer
 from yolo.model.yolo import YOLO
-from yolo.tools.drawer import draw_bboxes
 from yolo.utils.logger import logger
 from yolo.utils.solver_utils import make_ap_table


-def custom_logger(quite: bool = False):
-    if quite:
-        logger.removeHandler("YOLO_logger")
-
-
 # TODO: should be moved to correct position
 def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
+    seed_everything(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
@@ -62,189 +50,211 @@ def set_seed(seed):
     torch.backends.cudnn.benchmark = False


-class ProgressLogger(Progress):
-    def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
-        …
-        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
-
-        progress_bar = (
-            SpinnerColumn(),
-            TextColumn("[progress.description]{task.description}"),
-            BarColumn(bar_width=None),
-            TextColumn("{task.completed:.0f}/{task.total:.0f}"),
-            TimeRemainingColumn(),
-        )
-        self.ap_table = Table()
-        # TODO: load maxlen by config files
-        self.ap_past_list = deque(maxlen=5)
-        self.last_result = 0
-        super().__init__(*args, *progress_bar, **kwargs)
-
-        self.use_wandb = cfg.use_wandb
-        if self.use_wandb and self.local_rank == 0:
-            wandb.errors.term._log = custom_wandb_log
-            self.wandb = wandb.init(
-                project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
-            )
-
-        self.use_tensorboard = cfg.use_tensorboard
-        if self.use_tensorboard and self.local_rank == 0:
-            from torch.utils.tensorboard import SummaryWriter
-
-            …
-
-    def rank_check(logging_function):
-        def wrapper(self, *args, **kwargs):
-            if getattr(self, "local_rank", 0) != 0:
-                return
-            return logging_function(self, *args, **kwargs)
-
-        return wrapper
-
-    …
-
-    @rank_check
-    def one_batch(self, batch_info: Dict[str, Tensor] = None):
-        epoch_descript = "[cyan]" + self.task + "[white] |"
-        batch_descript = "|"
-        if self.task == "Train":
-            self.update(self.task_epoch, advance=1 / self.num_batches)
-        for info_name, info_val in batch_info.items():
-            epoch_descript += f"{info_name: ^9}|"
-            batch_descript += f" {info_val:2.2f} |"
-        self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
-        if hasattr(self, "task_epoch"):
-            self.update(self.task_epoch, description=epoch_descript)
-
-    @rank_check
-    def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
-        if self.task == "Train":
-            prefix = "Loss"
-        elif self.task == "Validate":
-            prefix = "Metrics"
-        batch_info = {f"{prefix}/{key}": value for key, value in batch_info.items()}
-        if self.use_wandb:
-            self.wandb.log(batch_info, step=epoch_idx)
-        if self.use_tensorboard:
-            for key, value in batch_info.items():
-                self.tb_writer.add_scalar(key, value, epoch_idx)
-
-        self.remove_task(self.batch_task)
-
-    @rank_check
-    def visualize_image(
-        self,
-        images: Optional[Tensor] = None,
-        ground_truth: Optional[Tensor] = None,
-        prediction: Optional[Union[List[Tensor], Tensor]] = None,
-        epoch_idx: int = 0,
-    ) -> None:
-        """
-        Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
-
-        Args:
-            images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
-            ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
-            prediction (prediction: Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6) or (N, 6). Defaults to None.
-            epoch_idx (int): Current epoch index. Defaults to 0.
-        """
-        if images is not None:
-            images = images[0] if images.ndim == 4 else images
-            if self.use_wandb:
-                wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Input Image", images, 1)
-
-        if ground_truth is not None:
-            gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
-            if self.use_wandb:
-                wandb.log(
-                    {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
-                    step=epoch_idx,
-                )
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
-
-        if prediction is not None:
-            pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
-            if self.use_wandb:
-                wandb.log(
-                    {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
-                    step=epoch_idx,
-                )
-            if self.use_tensorboard:
-                self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
-
-    @rank_check
-    def start_pycocotools(self):
-        self.batch_task = self.add_task("[green]Run pycocotools", total=1)
-
-    @rank_check
-    def finish_pycocotools(self, result, epoch_idx=-1):
-        ap_table, ap_main = make_ap_table(result * 100, self.ap_past_list, self.last_result, epoch_idx)
-        self.last_result = np.maximum(result, self.last_result)
-        self.ap_past_list.append((epoch_idx, ap_main))
-        self.ap_table = ap_table
-
-        if self.use_wandb:
-            self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
-        if self.use_tensorboard:
-            # TODO: waiting torch bugs fix, https://github.com/pytorch/pytorch/issues/32651
-            self.tb_writer.add_scalar("PyCOCO/AP @ .5:.95", ap_main[2], epoch_idx)
-            self.tb_writer.add_scalar("PyCOCO/AP @ .5", ap_main[5], epoch_idx)
-
-        self.update(self.batch_task, advance=1)
-        self.refresh()
-        self.remove_task(self.batch_task)
-
-    @rank_check
-    def finish_train(self):
-        …
-
-
-def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
-    …
+class YOLOCustomProgress(CustomProgress):
+    def get_renderable(self):
+        renderable = Group(*self.get_renderables())
+        if hasattr(self, "table"):
+            renderable = Group(*self.get_renderables(), self.table)
+        return renderable
+
+
+class YOLORichProgressBar(RichProgressBar):
+    @override
+    def _init_progress(self, trainer: "Trainer") -> None:
+        if self.is_enabled and (self.progress is None or self._progress_stopped):
+            self._reset_progress_bar_ids()
+            reconfigure(**self._console_kwargs)
+            self._console = Console()
+            self._console.clear_live()
+            self.progress = YOLOCustomProgress(
+                *self.configure_columns(trainer),
+                auto_refresh=False,
+                disable=self.is_disabled,
+                console=self._console,
+            )
+            self.progress.start()
+
+            self._progress_stopped = False
+
+            self.max_result = 0
+            self.past_results = deque(maxlen=5)
+            self.progress.table = Table()
+
+    @override
+    def _get_train_description(self, current_epoch: int) -> str:
+        return Text("[cyan]Train [white]|")
+
+    @override
+    def on_train_start(self, trainer, pl_module):
+        self._init_progress(trainer)
+        num_epochs = trainer.max_epochs - 1
+        self.task_epoch = self._add_task(
+            total_batches=num_epochs,
+            description=f"[cyan]Start Training {num_epochs} epochs",
+        )
+        self.max_result = 0
+        self.past_results.clear()
+        self.progress.update(self.task_epoch, advance=-0.5)
+
+    @override
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
+        self._update(self.train_progress_bar_id, batch_idx + 1)
+        self._update_metrics(trainer, pl_module)
+        epoch_descript = "[cyan]Train [white]|"
+        batch_descript = "[green]Train [white]|"
+        metrics = self.get_metrics(trainer, pl_module)
+        metrics.pop("v_num")
+        for metrics_name, metrics_val in metrics.items():
+            if "Loss_step" in metrics_name:
+                epoch_descript += f"{metrics_name.removesuffix('_step'): ^9}|"
+                batch_descript += f" {metrics_val:2.2f} |"
+
+        self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
+        self.progress.update(self.train_progress_bar_id, description=batch_descript)
+        self.refresh()
+
+    @override
+    def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        self._update_metrics(trainer, pl_module)
+        self.progress.remove_task(self.train_progress_bar_id)
+        self.train_progress_bar_id = None
+
+    @override
+    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        if trainer.state.fn == "fit":
+            self._update_metrics(trainer, pl_module)
+        self.reset_dataloader_idx_tracker()
+        all_metrics = self.get_metrics(trainer, pl_module)
+
+        ap_ar_list = [
+            key
+            for key in all_metrics.keys()
+            if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
+        ]
+        score = np.array([all_metrics[key] for key in ap_ar_list]) * 100
+
+        self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
+        self.max_result = np.maximum(score, self.max_result)
+        self.past_results.append((trainer.current_epoch, ap_main))
+
+    @override
+    def refresh(self) -> None:
+        if self.progress:
+            self.progress.refresh()
+
+    @property
+    def validation_description(self) -> str:
+        return "[green]Validation"
+
+
+class YOLORichModelSummary(RichModelSummary):
+
+    from typing_extensions import override
+
+    @staticmethod
+    @override
+    def summarize(
+        summary_data: List[Tuple[str, List[str]]],
+        total_parameters: int,
+        trainable_parameters: int,
+        model_size: float,
+        total_training_modes: Dict[str, int],
+        **summarize_kwargs: Any,
+    ) -> None:
+        from lightning.pytorch.utilities.model_summary import get_human_readable_count
+        from rich import get_console
+        from rich.table import Table
+
+        console = get_console()
+
+        header_style: str = summarize_kwargs.get("header_style", "bold magenta")
+        table = Table(header_style=header_style)
+        table.add_column(" ", style="dim")
+        table.add_column("Name", justify="left", no_wrap=True)
+        table.add_column("Type")
+        table.add_column("Params", justify="right")
+        table.add_column("Mode")
+
+        column_names = list(zip(*summary_data))[0]
+
+        for column_name in ["In sizes", "Out sizes"]:
+            if column_name in column_names:
+                table.add_column(column_name, justify="right", style="white")
+
+        rows = list(zip(*(arr[1] for arr in summary_data)))
+        for row in rows:
+            table.add_row(*row)
+
+        console.print(table)
+
+        parameters = []
+        for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
+            parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
+
+        grid = Table(header_style=header_style)
+        table.add_column(" ", style="dim")
+        grid.add_column("[bold]Attributes[/]")
+        grid.add_column("Value")
+
+        grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
+        grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
+        grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
+        grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
+        grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
+        grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")
+
+        console.print(grid)
+
+
+class ImageLogger(Callback):
+    def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
+        if batch_idx != 0:
+            return
+        batch_size, images, targets, rev_tensor, img_paths = batch
+        gt_boxes = targets[0] if targets.ndim == 3 else targets
+        pred_boxes = outputs[0] if isinstance(outputs, list) else outputs
+        images = [images[0]]
+        step = trainer.current_epoch
+        for logger in trainer.loggers:
+            if isinstance(logger, WandbLogger):
+                logger.log_image("Input Image", images, step=step)
+                logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)])
+                logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)])
+
+
+def setup(cfg: Config):
+    if hasattr(cfg, "quite"):
+        logger.removeHandler("YOLO_logger")
+        return
+
+    class EmojiFormatter(logging.Formatter):
+        def format(self, record):
+            return f":high_voltage: {super().format(record)}"
+
+    rich_handler = RichHandler(markup=True)
+    rich_handler.setFormatter(EmojiFormatter("%(message)s"))
+    lightning_logger = logging.getLogger("lightning.pytorch")
+    lightning_logger.handlers.clear()
+    lightning_logger.addHandler(rich_handler)
+
+    def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
+        if silent:
+            return
+        for line in string.split("\n"):
+            logger.info(Text.from_ansi(":globe_with_meridians: " + line))
+
+    wandb.errors.term._log = custom_wandb_log
+
+    save_path = validate_log_directory(cfg, cfg.name)
+
+    progress, loggers = [], []
+    progress.append(YOLORichProgressBar())
+    progress.append(YOLORichModelSummary())
+    progress.append(ImageLogger())
+
+    loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))
+
+    return progress, loggers


 def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
@@ -291,7 +301,7 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
     )

     save_path.mkdir(parents=True, exist_ok=True)
-    logger.info(f"📄 Created log folder: [blue b u]{save_path}[/]")
+    logger.info(f"📄 Created log folder: [blue b u]123{save_path}[/]")
     logger.addHandler(FileHandler(save_path / "output.log"))
     return save_path

@@ -327,4 +337,4 @@ def log_bbox(
         bbox_entry["scores"] = {"confidence": conf[0]}
     bbox_list.append(bbox_entry)

-    return bbox_list
+    return {"predictions": {"box_data": bbox_list}}
```
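`setup` now returns the callback and logger lists that `yolo/lazy.py` hands to the `Trainer`, and it reroutes Lightning's internal log records through rich. The rerouting trick in isolation, using the commit's `EmojiFormatter` idea:

```python
# Sketch: rerouting Lightning's internal logger through rich, as setup() does.
import logging

from rich.logging import RichHandler


class EmojiFormatter(logging.Formatter):
    def format(self, record):
        # prepend a rich emoji shortcode to every Lightning log line
        return f":high_voltage: {super().format(record)}"


rich_handler = RichHandler(markup=True)
rich_handler.setFormatter(EmojiFormatter("%(message)s"))

lightning_logger = logging.getLogger("lightning.pytorch")
lightning_logger.handlers.clear()          # drop Lightning's default handler
lightning_logger.addHandler(rich_handler)  # rich renders markup and emoji

lightning_logger.warning("GPU available: [bold]True[/]")  # rendered by rich
```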
yolo/utils/model_utils.py  CHANGED

```diff
@@ -56,23 +56,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         {"params": conv_params},
         {"params": norm_params, "weight_decay": 0},
     ]
-
-    def next_epoch(self, batch_num):
-        self.min_lr = self.max_lr
-        self.max_lr = [param["lr"] for param in self.param_groups]
-        self.batch_num = batch_num
-        self.batch_idx = 0
-
-    def next_batch(self):
-        self.batch_idx += 1
-        for lr_idx, param_group in enumerate(self.param_groups):
-            min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
-            param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
-
-    optimizer_class.next_batch = next_batch
-    optimizer_class.next_epoch = next_epoch
     optimizer = optimizer_class(model_parameters, **optim_cfg.args)
-
+    # TODO: implement batch lr schedular when warm up
     return optimizer
@@ -168,6 +153,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
     batch_json = []
     for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
         scale, shift = box_reverse.split([1, 4])
+        bboxes = bboxes.clone()
         bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
         bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
         for cls, *pos, conf in bboxes:
```
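The added `bboxes.clone()` matters because the lines that follow mutate the tensor in place; without the copy, the caller's prediction tensor, which the validation path still uses afterwards, would be silently rescaled. A minimal demonstration of the aliasing hazard:

```python
# Sketch: why predicts_to_json now clones before its in-place edits.
import torch

preds = torch.tensor([[0.0, 10.0, 10.0, 50.0, 50.0, 0.9]])  # caller's tensor


def without_clone(bboxes):
    bboxes[:, 1:5] = bboxes[:, 1:5] / 2  # in-place: mutates the caller's data


def with_clone(bboxes):
    bboxes = bboxes.clone()              # detach from the caller's storage
    bboxes[:, 1:5] = bboxes[:, 1:5] / 2  # safe: only the copy changes
    return bboxes


without_clone(preds)
print(preds[0, 1].item())  # 5.0 -> caller's boxes were corrupted

preds = torch.tensor([[0.0, 10.0, 10.0, 50.0, 50.0, 0.9]])
with_clone(preds)
print(preds[0, 1].item())  # 10.0 -> caller's boxes intact
```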
yolo/utils/solver_utils.py  CHANGED

```diff
@@ -1,5 +1,6 @@
 import contextlib
 import io
+from typing import Dict

 import numpy as np
 from pycocotools.coco import COCO
@@ -17,7 +18,7 @@ def calculate_ap(coco_gt: COCO, pd_path):
     return coco_eval.stats


-def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
+def make_ap_table(score: Dict[str, float], past_result=[], max_result=None, epoch=-1):
     ap_table = Table()
     ap_table.add_column("Epoch", justify="center", style="white", width=5)
     ap_table.add_column("Avg. Precision", justify="left", style="cyan")
@@ -30,7 +31,7 @@ def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
     if past_result:
         ap_table.add_row()

-    color = np.where(last_score <= score, "[green]", "[red]")
+    color = np.where(max_result <= score, "[green]", "[red]")

     this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
     metrics = [
```
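The `np.where` call vectorizes the table's green/red highlighting: each score is compared elementwise against the best result seen so far, so improved metrics render green and regressions red. In isolation, with invented scores:

```python
# Sketch: the per-metric color selection used by make_ap_table.
import numpy as np

max_result = np.array([45.0, 62.0])  # best AP @ .5:.95 and AP @ .5 so far (toy values)
score = np.array([46.3, 61.1])       # this epoch's results

color = np.where(max_result <= score, "[green]", "[red]")
print(color)  # ['[green]' '[red]'] -> first metric improved, second regressed
```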