henry000 commited on
Commit
8b3b3ef
·
1 Parent(s): 802cb12

✨ [New] use lightning framework to training!

Browse files
yolo/__init__.py CHANGED
@@ -2,18 +2,22 @@ from yolo.config.config import Config, NMSConfig
2
  from yolo.model.yolo import create_model
3
  from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
  from yolo.tools.drawer import draw_bboxes
5
- from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
7
  from yolo.utils.deploy_utils import FastModelLoader
8
- from yolo.utils.logging_utils import ProgressLogger, custom_logger
 
 
 
 
9
  from yolo.utils.model_utils import PostProccess
10
 
11
  all = [
12
  "create_model",
13
  "Config",
14
- "ProgressLogger",
15
  "NMSConfig",
16
- "custom_logger",
17
  "validate_log_directory",
18
  "draw_bboxes",
19
  "Vec2Box",
@@ -21,10 +25,9 @@ all = [
21
  "bbox_nms",
22
  "create_converter",
23
  "AugmentationComposer",
 
24
  "create_dataloader",
25
  "FastModelLoader",
26
- "ModelTester",
27
- "ModelTrainer",
28
- "ModelValidator",
29
  "PostProccess",
30
  ]
 
2
  from yolo.model.yolo import create_model
3
  from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
  from yolo.tools.drawer import draw_bboxes
5
+ from yolo.tools.solver import TrainModel
6
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
7
  from yolo.utils.deploy_utils import FastModelLoader
8
+ from yolo.utils.logging_utils import (
9
+ ImageLogger,
10
+ YOLORichModelSummary,
11
+ YOLORichProgressBar,
12
+ )
13
  from yolo.utils.model_utils import PostProccess
14
 
15
  all = [
16
  "create_model",
17
  "Config",
18
+ "YOLORichProgressBar",
19
  "NMSConfig",
20
+ "YOLORichModelSummary",
21
  "validate_log_directory",
22
  "draw_bboxes",
23
  "Vec2Box",
 
25
  "bbox_nms",
26
  "create_converter",
27
  "AugmentationComposer",
28
+ "ImageLogger",
29
  "create_dataloader",
30
  "FastModelLoader",
31
+ "TrainModel",
 
 
32
  "PostProccess",
33
  ]
yolo/lazy.py CHANGED
@@ -2,41 +2,36 @@ import sys
2
  from pathlib import Path
3
 
4
  import hydra
 
5
 
6
  project_root = Path(__file__).resolve().parent.parent
7
  sys.path.append(str(project_root))
8
 
9
  from yolo.config.config import Config
10
- from yolo.model.yolo import create_model
11
- from yolo.tools.data_loader import create_dataloader
12
- from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
13
- from yolo.utils.bounding_box_utils import create_converter
14
- from yolo.utils.deploy_utils import FastModelLoader
15
- from yolo.utils.logging_utils import ProgressLogger
16
- from yolo.utils.model_utils import get_device
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
- progress = ProgressLogger(cfg, exp_name=cfg.name)
22
- device, use_ddp = get_device(cfg.device)
23
- dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
24
- if getattr(cfg.task, "fast_inference", False):
25
- model = FastModelLoader(cfg).load_model(device)
26
- else:
27
- model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
28
- model = model.to(device)
29
-
30
- converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
31
-
32
- if cfg.task.task == "train":
33
- solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
34
- if cfg.task.task == "validation":
35
- solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
36
- if cfg.task.task == "inference":
37
- solver = ModelTester(cfg, model, converter, progress, device)
38
- progress.start()
39
- solver.solve(dataloader)
40
 
41
 
42
  if __name__ == "__main__":
 
2
  from pathlib import Path
3
 
4
  import hydra
5
+ from lightning import Trainer
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.tools.solver import TrainModel, ValidateModel
12
+ from yolo.utils.logging_utils import setup
 
 
 
 
 
13
 
14
 
15
  @hydra.main(config_path="config", config_name="config", version_base=None)
16
  def main(cfg: Config):
17
+ callbacks, loggers = setup(cfg)
18
+
19
+ trainer = Trainer(
20
+ accelerator="cuda",
21
+ max_epochs=getattr(cfg.task, "epoch", None),
22
+ precision="16-mixed",
23
+ callbacks=callbacks,
24
+ logger=loggers,
25
+ log_every_n_steps=1,
26
+ )
27
+
28
+ match cfg.task.task:
29
+ case "train":
30
+ model = TrainModel(cfg)
31
+ trainer.fit(model)
32
+ case "validation":
33
+ model = ValidateModel(cfg)
34
+ trainer.validate(model)
 
35
 
36
 
37
  if __name__ == "__main__":
yolo/tools/solver.py CHANGED
@@ -1,267 +1,89 @@
1
- import contextlib
2
- import io
3
- import json
4
- import os
5
- import time
6
- from collections import defaultdict
7
- from pathlib import Path
8
- from typing import Dict, Optional
9
 
10
- import torch
11
- from pycocotools.coco import COCO
12
- from torch import Tensor, distributed
13
- from torch.cuda.amp import GradScaler, autocast
14
- from torch.nn.parallel import DistributedDataParallel as DDP
15
- from torch.utils.data import DataLoader
16
-
17
- from yolo.config.config import Config, DatasetConfig, TrainConfig, ValidationConfig
18
- from yolo.model.yolo import YOLO
19
- from yolo.tools.data_loader import StreamDataLoader, create_dataloader
20
- from yolo.tools.drawer import draw_bboxes, draw_model
21
  from yolo.tools.loss_functions import create_loss_function
22
- from yolo.utils.bounding_box_utils import Vec2Box, calculate_map
23
- from yolo.utils.dataset_utils import locate_label_paths
24
- from yolo.utils.logger import logger
25
- from yolo.utils.logging_utils import ProgressLogger, log_model_structure
26
- from yolo.utils.model_utils import (
27
- ExponentialMovingAverage,
28
- PostProccess,
29
- collect_prediction,
30
- create_optimizer,
31
- create_scheduler,
32
- predicts_to_json,
33
- )
34
- from yolo.utils.solver_utils import calculate_ap
35
-
36
 
37
- class ModelTrainer:
38
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device, use_ddp: bool):
39
- train_cfg: TrainConfig = cfg.task
40
- self.model = model if not use_ddp else DDP(model, device_ids=[device])
41
- self.use_ddp = use_ddp
42
- self.vec2box = vec2box
43
- self.device = device
44
- self.optimizer = create_optimizer(model, train_cfg.optimizer)
45
- self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
46
- self.loss_fn = create_loss_function(cfg, vec2box)
47
- self.progress = progress
48
- self.num_epochs = cfg.task.epoch
49
- self.mAPs_dict = defaultdict(list)
50
 
51
- self.weights_dir = self.progress.save_path / "weights"
52
- self.weights_dir.mkdir(exist_ok=True)
 
 
53
 
54
- if not progress.quite_mode:
55
- log_model_structure(model.model)
56
- draw_model(model=model)
57
 
58
- self.validation_dataloader = create_dataloader(
59
- cfg.task.validation.data, cfg.dataset, cfg.task.validation.task, use_ddp
60
- )
61
- self.validator = ModelValidator(cfg.task.validation, cfg.dataset, model, vec2box, progress, device)
62
 
63
- if getattr(train_cfg.ema, "enabled", False):
64
- self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
 
 
 
 
65
  else:
66
- self.ema = None
67
- self.scaler = GradScaler()
68
-
69
- def train_one_batch(self, images: Tensor, targets: Tensor):
70
- images, targets = images.to(self.device), targets.to(self.device)
71
- self.optimizer.zero_grad()
72
-
73
- with autocast():
74
- predicts = self.model(images)
75
- aux_predicts = self.vec2box(predicts["AUX"])
76
- main_predicts = self.vec2box(predicts["Main"])
77
- loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
78
-
79
- self.scaler.scale(loss).backward()
80
- self.scaler.unscale_(self.optimizer)
81
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
82
- self.scaler.step(self.optimizer)
83
- self.scaler.update()
84
-
85
- return loss_item
86
-
87
- def train_one_epoch(self, dataloader):
88
- self.model.train()
89
- total_loss = defaultdict(float)
90
- total_samples = 0
91
- self.optimizer.next_epoch(len(dataloader))
92
- for batch_size, images, targets, *_ in dataloader:
93
- self.optimizer.next_batch()
94
- loss_each = self.train_one_batch(images, targets)
95
-
96
- for loss_name, loss_val in loss_each.items():
97
- if self.use_ddp: # collecting loss for each batch
98
- distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG)
99
- total_loss[loss_name] += loss_val.item() * batch_size
100
- total_samples += batch_size
101
- self.progress.one_batch(loss_each)
102
-
103
- for loss_val in total_loss.values():
104
- loss_val /= total_samples
105
-
106
- if self.scheduler:
107
- self.scheduler.step()
108
-
109
- return total_loss
110
-
111
- def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
112
- file_name = file_name or f"E{epoch_idx:03d}.pt"
113
- file_path = self.weights_dir / file_name
114
-
115
- checkpoint = {
116
- "epoch": epoch_idx,
117
- "model_state_dict": self.model.state_dict(),
118
- "optimizer_state_dict": self.optimizer.state_dict(),
119
- }
120
- if self.ema:
121
- self.ema.apply_shadow()
122
- checkpoint["model_state_dict_ema"] = self.model.state_dict()
123
- self.ema.restore()
124
-
125
- logger.info(f"💾 success save at {file_path}")
126
- torch.save(checkpoint, file_path)
127
-
128
- def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
129
- save_flag = True
130
- for mAP_key, mAP_val in mAPs.items():
131
- self.mAPs_dict[mAP_key].append(mAP_val)
132
- if mAP_val < max(self.mAPs_dict[mAP_key]):
133
- save_flag = False
134
- return save_flag
135
-
136
- def solve(self, dataloader: DataLoader):
137
- logger.info("🚄 Start Training!")
138
- num_epochs = self.num_epochs
139
-
140
- self.progress.start_train(num_epochs)
141
- for epoch_idx in range(num_epochs):
142
- if self.use_ddp:
143
- dataloader.sampler.set_epoch(epoch_idx)
144
 
145
- self.progress.start_one_epoch(len(dataloader), "Train", self.optimizer, epoch_idx)
146
- epoch_loss = self.train_one_epoch(dataloader)
147
- self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
148
-
149
- mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
150
- if mAPs is not None and self.good_epoch(mAPs):
151
- self.save_checkpoint(epoch_idx=epoch_idx)
152
- # TODO: save model if result are better than before
153
- self.progress.finish_train()
154
-
155
-
156
- class ModelTester:
157
- def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
158
- self.model = model
159
- self.device = device
160
- self.progress = progress
161
-
162
- self.post_proccess = PostProccess(vec2box, cfg.task.nms)
163
- self.save_path = progress.save_path / "images"
164
- os.makedirs(self.save_path, exist_ok=True)
165
- self.save_predict = getattr(cfg.task, "save_predict", None)
166
- self.idx2label = cfg.dataset.class_list
167
-
168
- def solve(self, dataloader: StreamDataLoader):
169
- logger.info("👀 Start Inference!")
170
- if isinstance(self.model, torch.nn.Module):
171
- self.model.eval()
172
-
173
- if dataloader.is_stream:
174
- import cv2
175
- import numpy as np
176
-
177
- last_time = time.time()
178
- try:
179
- for idx, (images, rev_tensor, origin_frame) in enumerate(dataloader):
180
- images = images.to(self.device)
181
- rev_tensor = rev_tensor.to(self.device)
182
- with torch.no_grad():
183
- predicts = self.model(images)
184
- predicts = self.post_proccess(predicts, rev_tensor)
185
- img = draw_bboxes(origin_frame, predicts, idx2label=self.idx2label)
186
-
187
- if dataloader.is_stream:
188
- img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
189
- fps = 1 / (time.time() - last_time)
190
- cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
191
- last_time = time.time()
192
- cv2.imshow("Prediction", img)
193
- if cv2.waitKey(1) & 0xFF == ord("q"):
194
- break
195
- if not self.save_predict:
196
- continue
197
- if self.save_predict != False:
198
- save_image_path = self.save_path / f"frame{idx:03d}.png"
199
- img.save(save_image_path)
200
- logger.info(f"💾 Saved visualize image at {save_image_path}")
201
-
202
- except (KeyboardInterrupt, Exception) as e:
203
- dataloader.stop_event.set()
204
- dataloader.stop()
205
- if isinstance(e, KeyboardInterrupt):
206
- logger.error("User Keyboard Interrupt")
207
- else:
208
- raise e
209
- dataloader.stop()
210
-
211
-
212
- class ModelValidator:
213
- def __init__(
214
- self,
215
- validation_cfg: ValidationConfig,
216
- dataset_cfg: DatasetConfig,
217
- model: YOLO,
218
- vec2box: Vec2Box,
219
- progress: ProgressLogger,
220
- device,
221
- ):
222
- self.model = model
223
- self.device = device
224
- self.progress = progress
225
-
226
- self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
227
- self.json_path = self.progress.save_path / "predict.json"
228
-
229
- with contextlib.redirect_stdout(io.StringIO()):
230
- # TODO: load with config file
231
- json_path, _ = locate_label_paths(Path(dataset_cfg.path), dataset_cfg.get("validation", "val"))
232
- if json_path:
233
- self.coco_gt = COCO(json_path)
234
-
235
- def solve(self, dataloader, epoch_idx=1):
236
- # logger.info("🧪 Start Validation!")
237
- self.model.eval()
238
- predict_json, mAPs = [], defaultdict(list)
239
- self.progress.start_one_epoch(len(dataloader), task="Validate")
240
- for batch_size, images, targets, rev_tensor, img_paths in dataloader:
241
- images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
242
- with torch.no_grad():
243
- predicts = self.model(images)
244
- predicts = self.post_proccess(predicts)
245
- for idx, predict in enumerate(predicts):
246
- mAP = calculate_map(predict, targets[idx])
247
- for mAP_key, mAP_val in mAP.items():
248
- mAPs[mAP_key].append(mAP_val)
249
-
250
- avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
251
- self.progress.one_batch(avg_mAPs)
252
 
253
- predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
254
- self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
255
- self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
256
 
257
- with open(self.json_path, "w") as f:
258
- predict_json = collect_prediction(predict_json, self.progress.local_rank)
259
- if self.progress.local_rank != 0:
260
- return
261
- json.dump(predict_json, f)
262
- if hasattr(self, "coco_gt"):
263
- self.progress.start_pycocotools()
264
- result = calculate_ap(self.coco_gt, predict_json)
265
- self.progress.finish_pycocotools(result, epoch_idx)
266
 
267
- return avg_mAPs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lightning import LightningModule
2
+ from torchmetrics.detection import MeanAveragePrecision
 
 
 
 
 
 
3
 
4
+ from yolo.config.config import Config
5
+ from yolo.model.yolo import create_model
6
+ from yolo.tools.data_loader import create_dataloader
 
 
 
 
 
 
 
 
7
  from yolo.tools.loss_functions import create_loss_function
8
+ from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
9
+ from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ class BaseModel(LightningModule):
13
+ def __init__(self, cfg: Config):
14
+ super().__init__()
15
+ self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
16
 
17
+ def forward(self, x):
18
+ return self.model(x)
 
19
 
 
 
 
 
20
 
21
+ class ValidateModel(BaseModel):
22
+ def __init__(self, cfg: Config):
23
+ super().__init__(cfg)
24
+ self.cfg = cfg
25
+ if self.cfg.task.task == "validation":
26
+ self.validation_cfg = self.cfg.task
27
  else:
28
+ self.validation_cfg = self.cfg.task.validation
29
+ self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ def setup(self, stage):
32
+ self.vec2box = create_converter(
33
+ self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
34
+ )
35
+ self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def val_dataloader(self):
38
+ return create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
 
39
 
40
+ def validation_step(self, batch, batch_idx):
41
+ batch_size, images, targets, rev_tensor, img_paths = batch
42
+ predicts = self.post_proccess(self(images))
43
+ batch_metrics = self.metric(
44
+ [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
45
+ )
 
 
 
46
 
47
+ self.log_dict(
48
+ {
49
+ "map": batch_metrics["map"],
50
+ "map_50": batch_metrics["map_50"],
51
+ },
52
+ on_step=True,
53
+ prog_bar=True,
54
+ logger=False,
55
+ batch_size=batch_size,
56
+ )
57
+ return predicts
58
+
59
+ def on_validation_epoch_end(self):
60
+ epoch_metrics = self.metric.compute()
61
+ del epoch_metrics["classes"]
62
+ self.log_dict(epoch_metrics, on_epoch=True, prog_bar=True, logger=True)
63
+
64
+
65
+ class TrainModel(ValidateModel):
66
+ def __init__(self, cfg: Config):
67
+ super().__init__(cfg)
68
+ self.cfg = cfg
69
+
70
+ def setup(self, stage):
71
+ super().setup(stage)
72
+ self.loss_fn = create_loss_function(self.cfg, self.vec2box)
73
+
74
+ def train_dataloader(self):
75
+ return create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
76
+
77
+ def training_step(self, batch, batch_idx):
78
+ batch_size, images, targets, *_ = batch
79
+ predicts = self(images)
80
+ aux_predicts = self.vec2box(predicts["AUX"])
81
+ main_predicts = self.vec2box(predicts["Main"])
82
+ loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
83
+ self.log_dict(loss_item, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
84
+ return loss * batch_size
85
+
86
+ def configure_optimizers(self):
87
+ optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
88
+ scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
89
+ return [optimizer], [scheduler]
yolo/utils/bounding_box_utils.py CHANGED
@@ -446,3 +446,10 @@ def calculate_map(predictions, ground_truths, iou_thresholds=arange(0.5, 1, 0.05
446
  "mAP.5:.95": torch.mean(torch.stack(aps)),
447
  }
448
  return mAP
 
 
 
 
 
 
 
 
446
  "mAP.5:.95": torch.mean(torch.stack(aps)),
447
  }
448
  return mAP
449
+
450
+
451
+ def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
452
+ bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
453
+ if prediction.size(1) == 6:
454
+ bbox["scores"] = prediction[:, 5]
455
+ return bbox
yolo/utils/logging_utils.py CHANGED
@@ -11,9 +11,7 @@ Example:
11
  custom_logger()
12
  """
13
 
14
- import os
15
- import random
16
- import sys
17
  from collections import deque
18
  from logging import FileHandler
19
  from pathlib import Path
@@ -22,39 +20,29 @@ from typing import Any, Dict, List, Optional, Tuple, Union
22
  import numpy as np
23
  import torch
24
  import wandb
25
- import wandb.errors.term
 
 
 
26
  from omegaconf import ListConfig
 
27
  from rich.console import Console, Group
28
- from rich.progress import (
29
- BarColumn,
30
- Progress,
31
- SpinnerColumn,
32
- TextColumn,
33
- TimeRemainingColumn,
34
- )
35
  from rich.table import Table
 
36
  from torch import Tensor
37
  from torch.nn import ModuleList
38
- from torch.optim import Optimizer
39
- from torchvision.transforms.functional import pil_to_tensor
40
 
41
  from yolo.config.config import Config, YOLOLayer
42
  from yolo.model.yolo import YOLO
43
- from yolo.tools.drawer import draw_bboxes
44
  from yolo.utils.logger import logger
45
  from yolo.utils.solver_utils import make_ap_table
46
 
47
 
48
- def custom_logger(quite: bool = False):
49
- if quite:
50
- logger.removeHandler("YOLO_logger")
51
-
52
-
53
  # TODO: should be moved to correct position
54
  def set_seed(seed):
55
- random.seed(seed)
56
- np.random.seed(seed)
57
- torch.manual_seed(seed)
58
  if torch.cuda.is_available():
59
  torch.cuda.manual_seed(seed)
60
  torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
@@ -62,189 +50,211 @@ def set_seed(seed):
62
  torch.backends.cudnn.benchmark = False
63
 
64
 
65
- class ProgressLogger(Progress):
66
- def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
67
- set_seed(cfg.lucky_number)
68
- self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
69
- self.quite_mode = self.local_rank or getattr(cfg, "quite", False)
70
- custom_logger(self.quite_mode)
71
- self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
72
-
73
- progress_bar = (
74
- SpinnerColumn(),
75
- TextColumn("[progress.description]{task.description}"),
76
- BarColumn(bar_width=None),
77
- TextColumn("{task.completed:.0f}/{task.total:.0f}"),
78
- TimeRemainingColumn(),
79
- )
80
- self.ap_table = Table()
81
- # TODO: load maxlen by config files
82
- self.ap_past_list = deque(maxlen=5)
83
- self.last_result = 0
84
- super().__init__(*args, *progress_bar, **kwargs)
85
-
86
- self.use_wandb = cfg.use_wandb
87
- if self.use_wandb and self.local_rank == 0:
88
- wandb.errors.term._log = custom_wandb_log
89
- self.wandb = wandb.init(
90
- project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
91
- )
92
 
93
- self.use_tensorboard = cfg.use_tensorboard
94
- if self.use_tensorboard and self.local_rank == 0:
95
- from torch.utils.tensorboard import SummaryWriter
96
 
97
- self.tb_writer = SummaryWriter(log_dir=self.save_path / "tensorboard")
98
- logger.info(f"📍 Enable TensorBoard locally at <blue><u>http://localhost:6006</></>")
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- def rank_check(logging_function):
101
- def wrapper(self, *args, **kwargs):
102
- if getattr(self, "local_rank", 0) != 0:
103
- return
104
- return logging_function(self, *args, **kwargs)
105
 
106
- return wrapper
 
 
107
 
108
- def get_renderable(self):
109
- renderable = Group(*self.get_renderables(), self.ap_table)
110
- return renderable
111
 
112
- @rank_check
113
- def start_train(self, num_epochs: int):
114
- self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
115
- self.update(self.task_epoch, advance=-0.5)
116
-
117
- @rank_check
118
- def start_one_epoch(
119
- self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
120
- ):
121
- self.num_batches = num_batches
122
- self.task = task
123
- if hasattr(self, "task_epoch"):
124
- self.update(self.task_epoch, description=f"[cyan] Preparing Data")
125
-
126
- if optimizer is not None:
127
- lr_values = [params["lr"] for params in optimizer.param_groups]
128
- lr_names = ["Learning Rate/bias", "Learning Rate/norm", "Learning Rate/conv"]
129
- if self.use_wandb:
130
- for lr_name, lr_value in zip(lr_names, lr_values):
131
- self.wandb.log({lr_name: lr_value}, step=epoch_idx)
132
-
133
- if self.use_tensorboard:
134
- for lr_name, lr_value in zip(lr_names, lr_values):
135
- self.tb_writer.add_scalar(lr_name, lr_value, global_step=epoch_idx)
136
-
137
- self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
138
-
139
- @rank_check
140
- def one_batch(self, batch_info: Dict[str, Tensor] = None):
141
- epoch_descript = "[cyan]" + self.task + "[white] |"
142
- batch_descript = "|"
143
- if self.task == "Train":
144
- self.update(self.task_epoch, advance=1 / self.num_batches)
145
- for info_name, info_val in batch_info.items():
146
- epoch_descript += f"{info_name: ^9}|"
147
- batch_descript += f" {info_val:2.2f} |"
148
- self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
149
- if hasattr(self, "task_epoch"):
150
- self.update(self.task_epoch, description=epoch_descript)
151
-
152
- @rank_check
153
- def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
154
- if self.task == "Train":
155
- prefix = "Loss"
156
- elif self.task == "Validate":
157
- prefix = "Metrics"
158
- batch_info = {f"{prefix}/{key}": value for key, value in batch_info.items()}
159
- if self.use_wandb:
160
- self.wandb.log(batch_info, step=epoch_idx)
161
- if self.use_tensorboard:
162
- for key, value in batch_info.items():
163
- self.tb_writer.add_scalar(key, value, epoch_idx)
164
-
165
- self.remove_task(self.batch_task)
166
-
167
- @rank_check
168
- def visualize_image(
169
- self,
170
- images: Optional[Tensor] = None,
171
- ground_truth: Optional[Tensor] = None,
172
- prediction: Optional[Union[List[Tensor], Tensor]] = None,
173
- epoch_idx: int = 0,
174
- ) -> None:
175
- """
176
- Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
177
-
178
- Args:
179
- images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
180
- ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
181
- prediction (prediction: Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6) or (N, 6). Defaults to None.
182
- epoch_idx (int): Current epoch index. Defaults to 0.
183
- """
184
- if images is not None:
185
- images = images[0] if images.ndim == 4 else images
186
- if self.use_wandb:
187
- wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
188
- if self.use_tensorboard:
189
- self.tb_writer.add_image("Media/Input Image", images, 1)
190
-
191
- if ground_truth is not None:
192
- gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
193
- if self.use_wandb:
194
- wandb.log(
195
- {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
196
- step=epoch_idx,
197
- )
198
- if self.use_tensorboard:
199
- self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
200
-
201
- if prediction is not None:
202
- pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
203
- if self.use_wandb:
204
- wandb.log(
205
- {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
206
- step=epoch_idx,
207
- )
208
- if self.use_tensorboard:
209
- self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
210
-
211
- @rank_check
212
- def start_pycocotools(self):
213
- self.batch_task = self.add_task("[green]Run pycocotools", total=1)
214
-
215
- @rank_check
216
- def finish_pycocotools(self, result, epoch_idx=-1):
217
- ap_table, ap_main = make_ap_table(result * 100, self.ap_past_list, self.last_result, epoch_idx)
218
- self.last_result = np.maximum(result, self.last_result)
219
- self.ap_past_list.append((epoch_idx, ap_main))
220
- self.ap_table = ap_table
221
-
222
- if self.use_wandb:
223
- self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
224
- if self.use_tensorboard:
225
- # TODO: waiting torch bugs fix, https://github.com/pytorch/pytorch/issues/32651
226
- self.tb_writer.add_scalar("PyCOCO/AP @ .5:.95", ap_main[2], epoch_idx)
227
- self.tb_writer.add_scalar("PyCOCO/AP @ .5", ap_main[5], epoch_idx)
228
-
229
- self.update(self.batch_task, advance=1)
230
  self.refresh()
231
- self.remove_task(self.batch_task)
232
 
233
- @rank_check
234
- def finish_train(self):
235
- self.remove_task(self.task_epoch)
236
- self.stop()
237
- if self.use_wandb:
238
- self.wandb.finish()
239
- if self.use_tensorboard:
240
- self.tb_writer.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
 
 
 
 
 
242
 
243
- def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
244
- if silent:
245
- return
246
- for line in string.split("\n"):
247
- logger.info("🌐 " + line)
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
 
250
  def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
@@ -291,7 +301,7 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
291
  )
292
 
293
  save_path.mkdir(parents=True, exist_ok=True)
294
- logger.info(f"📄 Created log folder: [bold gray]{save_path}[/]", extra={"markup": True})
295
  logger.addHandler(FileHandler(save_path / "output.log"))
296
  return save_path
297
 
@@ -327,4 +337,4 @@ def log_bbox(
327
  bbox_entry["scores"] = {"confidence": conf[0]}
328
  bbox_list.append(bbox_entry)
329
 
330
- return bbox_list
 
11
  custom_logger()
12
  """
13
 
14
+ import logging
 
 
15
  from collections import deque
16
  from logging import FileHandler
17
  from pathlib import Path
 
20
  import numpy as np
21
  import torch
22
  import wandb
23
+ from lightning import LightningModule, Trainer, seed_everything
24
+ from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
25
+ from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
26
+ from lightning.pytorch.loggers import WandbLogger
27
  from omegaconf import ListConfig
28
+ from rich import reconfigure
29
  from rich.console import Console, Group
30
+ from rich.logging import RichHandler
 
 
 
 
 
 
31
  from rich.table import Table
32
+ from rich.text import Text
33
  from torch import Tensor
34
  from torch.nn import ModuleList
35
+ from typing_extensions import override
 
36
 
37
  from yolo.config.config import Config, YOLOLayer
38
  from yolo.model.yolo import YOLO
 
39
  from yolo.utils.logger import logger
40
  from yolo.utils.solver_utils import make_ap_table
41
 
42
 
 
 
 
 
 
43
  # TODO: should be moved to correct position
44
  def set_seed(seed):
45
+ seed_everything(seed)
 
 
46
  if torch.cuda.is_available():
47
  torch.cuda.manual_seed(seed)
48
  torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
 
50
  torch.backends.cudnn.benchmark = False
51
 
52
 
53
+ class YOLOCustomProgress(CustomProgress):
54
+ def get_renderable(self):
55
+ renderable = Group(*self.get_renderables())
56
+ if hasattr(self, "table"):
57
+ renderable = Group(*self.get_renderables(), self.table)
58
+ return renderable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
 
 
60
 
61
+ class YOLORichProgressBar(RichProgressBar):
62
+ @override
63
+ def _init_progress(self, trainer: "Trainer") -> None:
64
+ if self.is_enabled and (self.progress is None or self._progress_stopped):
65
+ self._reset_progress_bar_ids()
66
+ reconfigure(**self._console_kwargs)
67
+ self._console = Console()
68
+ self._console.clear_live()
69
+ self.progress = YOLOCustomProgress(
70
+ *self.configure_columns(trainer),
71
+ auto_refresh=False,
72
+ disable=self.is_disabled,
73
+ console=self._console,
74
+ )
75
+ self.progress.start()
76
 
77
+ self._progress_stopped = False
 
 
 
 
78
 
79
+ self.max_result = 0
80
+ self.past_results = deque(maxlen=5)
81
+ self.progress.table = Table()
82
 
83
+ @override
84
+ def _get_train_description(self, current_epoch: int) -> str:
85
+ return Text("[cyan]Train [white]|")
86
 
87
+ @override
88
+ def on_train_start(self, trainer, pl_module):
89
+ self._init_progress(trainer)
90
+ num_epochs = trainer.max_epochs - 1
91
+ self.task_epoch = self._add_task(
92
+ total_batches=num_epochs,
93
+ description=f"[cyan]Start Training {num_epochs} epochs",
94
+ )
95
+ self.max_result = 0
96
+ self.past_results.clear()
97
+ self.progress.update(self.task_epoch, advance=-0.5)
98
+
99
+ @override
100
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
101
+ self._update(self.train_progress_bar_id, batch_idx + 1)
102
+ self._update_metrics(trainer, pl_module)
103
+ epoch_descript = "[cyan]Train [white]|"
104
+ batch_descript = "[green]Train [white]|"
105
+ metrics = self.get_metrics(trainer, pl_module)
106
+ metrics.pop("v_num")
107
+ for metrics_name, metrics_val in metrics.items():
108
+ if "Loss_step" in metrics_name:
109
+ epoch_descript += f"{metrics_name.removesuffix('_step'): ^9}|"
110
+ batch_descript += f" {metrics_val:2.2f} |"
111
+
112
+ self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
113
+ self.progress.update(self.train_progress_bar_id, description=batch_descript)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  self.refresh()
 
115
 
116
+ @override
117
+ def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
118
+ self._update_metrics(trainer, pl_module)
119
+ self.progress.remove_task(self.train_progress_bar_id)
120
+ self.train_progress_bar_id = None
121
+
122
+ @override
123
+ def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
124
+ if trainer.state.fn == "fit":
125
+ self._update_metrics(trainer, pl_module)
126
+ self.reset_dataloader_idx_tracker()
127
+ all_metrics = self.get_metrics(trainer, pl_module)
128
+
129
+ ap_ar_list = [
130
+ key
131
+ for key in all_metrics.keys()
132
+ if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
133
+ ]
134
+ score = np.array([all_metrics[key] for key in ap_ar_list]) * 100
135
+
136
+ self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
137
+ self.max_result = np.maximum(score, self.max_result)
138
+ self.past_results.append((trainer.current_epoch, ap_main))
139
+
140
+ @override
141
+ def refresh(self) -> None:
142
+ if self.progress:
143
+ self.progress.refresh()
144
+
145
+ @property
146
+ def validation_description(self) -> str:
147
+ return "[green]Validation"
148
+
149
+
150
+ class YOLORichModelSummary(RichModelSummary):
151
+
152
+ from typing_extensions import override
153
+
154
+ @staticmethod
155
+ @override
156
+ def summarize(
157
+ summary_data: List[Tuple[str, List[str]]],
158
+ total_parameters: int,
159
+ trainable_parameters: int,
160
+ model_size: float,
161
+ total_training_modes: Dict[str, int],
162
+ **summarize_kwargs: Any,
163
+ ) -> None:
164
+ from lightning.pytorch.utilities.model_summary import get_human_readable_count
165
+ from rich import get_console
166
+ from rich.table import Table
167
+
168
+ console = get_console()
169
+
170
+ header_style: str = summarize_kwargs.get("header_style", "bold magenta")
171
+ table = Table(header_style=header_style)
172
+ table.add_column(" ", style="dim")
173
+ table.add_column("Name", justify="left", no_wrap=True)
174
+ table.add_column("Type")
175
+ table.add_column("Params", justify="right")
176
+ table.add_column("Mode")
177
+
178
+ column_names = list(zip(*summary_data))[0]
179
+
180
+ for column_name in ["In sizes", "Out sizes"]:
181
+ if column_name in column_names:
182
+ table.add_column(column_name, justify="right", style="white")
183
+
184
+ rows = list(zip(*(arr[1] for arr in summary_data)))
185
+ for row in rows:
186
+ table.add_row(*row)
187
+
188
+ console.print(table)
189
+
190
+ parameters = []
191
+ for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
192
+ parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
193
+
194
+ grid = Table(header_style=header_style)
195
+ table.add_column(" ", style="dim")
196
+ grid.add_column("[bold]Attributes[/]")
197
+ grid.add_column("Value")
198
+
199
+ grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
200
+ grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
201
+ grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
202
+ grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
203
+ grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
204
+ grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")
205
+
206
+ console.print(grid)
207
+
208
+
209
+ class ImageLogger(Callback):
210
+ def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
211
+ if batch_idx != 0:
212
+ return
213
+ batch_size, images, targets, rev_tensor, img_paths = batch
214
+ gt_boxes = targets[0] if targets.ndim == 3 else targets
215
+ pred_boxes = outputs[0] if isinstance(outputs, list) else outputs
216
+ images = [images[0]]
217
+ step = trainer.current_epoch
218
+ for logger in trainer.loggers:
219
+ if isinstance(logger, WandbLogger):
220
+ logger.log_image("Input Image", images, step=step)
221
+ logger.log_image("Ground Truth", images, step=step, boxes=[log_bbox(gt_boxes)])
222
+ logger.log_image("Prediction", images, step=step, boxes=[log_bbox(pred_boxes)])
223
+
224
+
225
+ def setup(cfg: Config):
226
+ if hasattr(cfg, "quite"):
227
+ logger.removeHandler("YOLO_logger")
228
+ return
229
+
230
+ class EmojiFormatter(logging.Formatter):
231
+ def format(self, record):
232
+ return f":high_voltage: {super().format(record)}"
233
 
234
+ rich_handler = RichHandler(markup=True)
235
+ rich_handler.setFormatter(EmojiFormatter("%(message)s"))
236
+ lightning_logger = logging.getLogger("lightning.pytorch")
237
+ lightning_logger.handlers.clear()
238
+ lightning_logger.addHandler(rich_handler)
239
 
240
+ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
241
+ if silent:
242
+ return
243
+ for line in string.split("\n"):
244
+ logger.info(Text.from_ansi(":globe_with_meridians: " + line))
245
+
246
+ wandb.errors.term._log = custom_wandb_log
247
+
248
+ save_path = validate_log_directory(cfg, cfg.name)
249
+
250
+ progress, loggers = [], []
251
+ progress.append(YOLORichProgressBar())
252
+ progress.append(YOLORichModelSummary())
253
+ progress.append(ImageLogger())
254
+
255
+ loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))
256
+
257
+ return progress, loggers
258
 
259
 
260
  def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
 
301
  )
302
 
303
  save_path.mkdir(parents=True, exist_ok=True)
304
+ logger.info(f"📄 Created log folder: [blue b u]123{save_path}[/]")
305
  logger.addHandler(FileHandler(save_path / "output.log"))
306
  return save_path
307
 
 
337
  bbox_entry["scores"] = {"confidence": conf[0]}
338
  bbox_list.append(bbox_entry)
339
 
340
+ return {"predictions": {"box_data": bbox_list}}
yolo/utils/model_utils.py CHANGED
@@ -56,23 +56,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
56
  {"params": conv_params},
57
  {"params": norm_params, "weight_decay": 0},
58
  ]
59
-
60
- def next_epoch(self, batch_num):
61
- self.min_lr = self.max_lr
62
- self.max_lr = [param["lr"] for param in self.param_groups]
63
- self.batch_num = batch_num
64
- self.batch_idx = 0
65
-
66
- def next_batch(self):
67
- self.batch_idx += 1
68
- for lr_idx, param_group in enumerate(self.param_groups):
69
- min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
70
- param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
71
-
72
- optimizer_class.next_batch = next_batch
73
- optimizer_class.next_epoch = next_epoch
74
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
75
- optimizer.max_lr = [0.1, 0, 0]
76
  return optimizer
77
 
78
 
@@ -168,6 +153,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
168
  batch_json = []
169
  for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
170
  scale, shift = box_reverse.split([1, 4])
 
171
  bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
172
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
173
  for cls, *pos, conf in bboxes:
 
56
  {"params": conv_params},
57
  {"params": norm_params, "weight_decay": 0},
58
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
60
+ # TODO: implement batch lr schedular when warm up
61
  return optimizer
62
 
63
 
 
153
  batch_json = []
154
  for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
155
  scale, shift = box_reverse.split([1, 4])
156
+ bboxes = bboxes.clone()
157
  bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
158
  bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
159
  for cls, *pos, conf in bboxes:
yolo/utils/solver_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import contextlib
2
  import io
 
3
 
4
  import numpy as np
5
  from pycocotools.coco import COCO
@@ -17,7 +18,7 @@ def calculate_ap(coco_gt: COCO, pd_path):
17
  return coco_eval.stats
18
 
19
 
20
- def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
21
  ap_table = Table()
22
  ap_table.add_column("Epoch", justify="center", style="white", width=5)
23
  ap_table.add_column("Avg. Precision", justify="left", style="cyan")
@@ -30,7 +31,7 @@ def make_ap_table(score, past_result=[], last_score=None, epoch=-1):
30
  if past_result:
31
  ap_table.add_row()
32
 
33
- color = np.where(last_score <= score, "[green]", "[red]")
34
 
35
  this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
36
  metrics = [
 
1
  import contextlib
2
  import io
3
+ from typing import Dict
4
 
5
  import numpy as np
6
  from pycocotools.coco import COCO
 
18
  return coco_eval.stats
19
 
20
 
21
+ def make_ap_table(score: Dict[str, float], past_result=[], max_result=None, epoch=-1):
22
  ap_table = Table()
23
  ap_table.add_column("Epoch", justify="center", style="white", width=5)
24
  ap_table.add_column("Avg. Precision", justify="left", style="cyan")
 
31
  if past_result:
32
  ap_table.add_row()
33
 
34
+ color = np.where(max_result <= score, "[green]", "[red]")
35
 
36
  this_ap = ("AP @ .5:.95", color[0], score[0], "AP @ .5", color[1], score[1])
37
  metrics = [