File size: 13,655 Bytes
d5ba31a a7ef999 d5ba31a a7ef999 d5ba31a 8b3b3ef 999037c 0174b5b fa09d11 7f8fc3e 31cab2b fa09d11 67c056a f2370d7 8b3b3ef 1d2b161 c3b133e 1d2b161 999037c 8b3b3ef 0f9ffa2 8b3b3ef f2370d7 68d5954 8b3b3ef 0f9ffa2 b5fa3f1 68d5954 0174b5b c4cd90a 999037c d5ba31a 67c056a 8b3b3ef 67c056a 8b3b3ef 6e46676 7f8fc3e 8b3b3ef 1d2b161 8b3b3ef 7f8fc3e 8b3b3ef 1132d27 8b3b3ef 1132d27 8b3b3ef 999037c 8b3b3ef 1d2b161 8b3b3ef 1d2b161 8b3b3ef 1d2b161 8b3b3ef 999037c 6e46676 8b3b3ef 1d2b161 8b3b3ef 1d2b161 8b3b3ef 3ebbbd9 91a3f78 3ebbbd9 91a3f78 3e180a7 91a3f78 8b3b3ef 3ebbbd9 6e46676 8b3b3ef 3ebbbd9 c4cd90a 3ebbbd9 8b3b3ef de99a93 8b3b3ef 3ebbbd9 f2370d7 68d5954 0f9ffa2 c3b133e 0f9ffa2 16c6705 91a3f78 fa09d11 16c6705 b5fa3f1 16c6705 fa09d11 16c6705 fa09d11 16c6705 b038f54 3ebbbd9 0174b5b 16c6705 9c191b9 8b3b3ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 |
"""
Module for initializing logging tools used in machine learning and data processing.
Supports integration with Weights & Biases (wandb), Loguru, TensorBoard, and other
logging frameworks as needed.
This setup ensures consistent logging across various platforms, facilitating
effective monitoring and debugging.
Example:
from tools.logger import custom_logger
custom_logger()
"""
import logging
from collections import deque
from logging import FileHandler
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import wandb
from lightning import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import ListConfig
from rich import get_console, reconfigure
from rich.console import Console, Group
from rich.logging import RichHandler
from rich.table import Table
from rich.text import Text
from torch import Tensor
from torch.nn import ModuleList
from typing_extensions import override
from yolo.config.config import Config, YOLOLayer
from yolo.model.yolo import YOLO
from yolo.utils.logger import logger
from yolo.utils.model_utils import EMA
from yolo.utils.solver_utils import make_ap_table
# TODO: should be moved to correct position
def set_seed(seed):
    """
    Seed the random number generators used by the project.

    `seed_everything` is always called. When CUDA is available, the CUDA generators are
    seeded as well and cuDNN is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value forwarded unchanged to every seeding call.
    """
    seed_everything(seed)

    # Nothing CUDA-specific to configure on CPU-only machines.
    if not torch.cuda.is_available():
        return

    # `manual_seed_all` is called in addition to `manual_seed` for multi-GPU runs.
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
class YOLOCustomProgress(CustomProgress):
    """Progress display that renders an optional `table` attribute below the task renderables."""

    def get_renderable(self):
        """
        Build the renderable for the live display.

        Returns:
            Group: everything produced by `get_renderables()`, followed by `self.table`
            when that attribute has been set on the instance.
        """
        trailing = (self.table,) if hasattr(self, "table") else ()
        return Group(*self.get_renderables(), *trailing)
class YOLORichProgressBar(RichProgressBar):
    """
    Rich progress bar that shows per-step loss values in the task descriptions and a
    results table (built by `make_ap_table`) underneath the bars.
    """

    @override
    @rank_zero_only
    def _init_progress(self, trainer: "Trainer") -> None:
        """Create and start the `YOLOCustomProgress` display and reset the result-tracking state."""
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            reconfigure(**self._console_kwargs)
            self._console = Console()
            self._console.clear_live()
            self.progress = YOLOCustomProgress(
                *self.configure_columns(trainer),
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
            self.progress.start()

            self._progress_stopped = False

            # State consumed by `on_validation_end` when building the results table.
            self.max_result = 0
            self.past_results = deque(maxlen=5)
            self.progress.table = Table()

    @override
    def _get_train_description(self, current_epoch: int) -> str:
        """Return the fixed description used for the training task (a rich `Text` object)."""
        return Text("[cyan]Train [white]|")

    @override
    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Initialise the display, add the epoch-level task and reset the tracked results."""
        self._init_progress(trainer)
        num_epochs = trainer.max_epochs - 1
        self.task_epoch = self._add_task(
            total_batches=num_epochs,
            description=f"[cyan]Start Training {num_epochs} epochs",
        )
        self.max_result = 0
        self.past_results.clear()

    @override
    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
        """
        Advance the batch and epoch tasks and show every `*Loss_step` metric in their descriptions.

        The epoch task description carries the loss names and the batch task description the
        matching values, so the two lines read as a header row and a value row.
        """
        self._update(self.train_progress_bar_id, batch_idx + 1)
        self._update_metrics(trainer, pl_module)
        epoch_descript = "[cyan]Train [white]|"
        batch_descript = "[green]Train [white]|"
        metrics = self.get_metrics(trainer, pl_module)
        # "v_num" is only reported when a logger with a version is attached; tolerate its absence
        # instead of raising KeyError on every batch when training runs without loggers.
        metrics.pop("v_num", None)
        for metrics_name, metrics_val in metrics.items():
            if "Loss_step" in metrics_name:
                name_parts = metrics_name.removesuffix("_step").split("/")
                # Metric names are expected to look like "<prefix>/<name>"; fall back to the
                # whole name when there is no prefix rather than failing with IndexError.
                short_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]
                epoch_descript += f"{short_name: ^9}|"
                batch_descript += f"   {metrics_val:2.2f}  |"

        self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
        self.progress.update(self.train_progress_bar_id, description=batch_descript)
        self.refresh()

    @override
    @rank_zero_only
    def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        """Refresh the metrics one last time and remove the training task from the display."""
        self._update_metrics(trainer, pl_module)
        self.progress.remove_task(self.train_progress_bar_id)
        self.train_progress_bar_id = None

    @override
    @rank_zero_only
    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        """
        Rebuild the results table from the metrics whose names start with "map" or "mar".

        Step/epoch-suffixed duplicates are skipped. Scores are scaled by 100 before being handed
        to `make_ap_table`, and the running maximum and recent history are updated afterwards.
        """
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)
        self.reset_dataloader_idx_tracker()
        all_metrics = self.get_metrics(trainer, pl_module)

        ap_ar_list = [
            key for key in all_metrics if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
        ]
        score = np.array([all_metrics[key] for key in ap_ar_list]) * 100

        self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
        self.max_result = np.maximum(score, self.max_result)
        self.past_results.append((trainer.current_epoch, ap_main))

    @override
    def refresh(self) -> None:
        """Redraw the display if a progress object exists."""
        if self.progress:
            self.progress.refresh()

    @property
    def validation_description(self) -> str:
        """Description used for the validation task."""
        return "[green]Validation"
class YOLORichModelSummary(RichModelSummary):
    """Model summary callback that prints a layer table and a parameter-statistics table with rich."""

    @staticmethod
    @override
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
        total_training_modes: Dict[str, int],
        **summarize_kwargs: Any,
    ) -> None:
        """
        Print the model summary to the global rich console.

        Args:
            summary_data: Column-wise summary as `(column name, column values)` pairs.
            total_parameters: Total number of parameters.
            trainable_parameters: Number of trainable parameters.
            model_size: Estimated model size, printed in the "(MB)" row.
            total_training_modes: Module counts; the "train" and "eval" keys are read.
            **summarize_kwargs: Optional settings; only "header_style" is used
                (defaults to "bold magenta").
        """
        from lightning.pytorch.utilities.model_summary import get_human_readable_count

        console = get_console()

        header_style: str = summarize_kwargs.get("header_style", "bold magenta")
        table = Table(header_style=header_style)
        table.add_column(" ", style="dim")
        table.add_column("Name", justify="left", no_wrap=True)
        table.add_column("Type")
        table.add_column("Params", justify="right")
        table.add_column("Mode")

        # Collect names directly so that an empty summary does not raise IndexError.
        column_names = [entry[0] for entry in summary_data]
        for column_name in ["In sizes", "Out sizes"]:
            if column_name in column_names:
                table.add_column(column_name, justify="right", style="white")

        # summary_data is column-major; transpose it into table rows.
        for row in zip(*(entry[1] for entry in summary_data)):
            table.add_row(*row)

        console.print(table)

        parameters = [
            "{:<{}}".format(get_human_readable_count(int(param)), 10)
            for param in (trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size)
        ]

        # The statistics grid has exactly two columns. The original also added a column to the
        # already-printed layer table at this point, which had no effect and has been removed.
        grid = Table(header_style=header_style)
        grid.add_column("[bold]Attributes[/]")
        grid.add_column("Value")

        grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
        grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
        grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
        grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
        grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
        grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")

        console.print(grid)
class ImageLogger(Callback):
    """Callback that sends the first image of the first validation batch to every `WandbLogger`."""

    def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
        """
        Log the input image together with its ground-truth and predicted boxes.

        Only the batch with index 0 is logged, using the current epoch as the step. Loggers
        that are not `WandbLogger` instances are skipped.
        """
        if batch_idx != 0:
            return

        # The batch is a 5-tuple; only the images and targets are used here.
        _, batch_images, targets, _, _ = batch
        gt_boxes = targets if targets.ndim != 3 else targets[0]
        pred_boxes = outputs if not isinstance(outputs, list) else outputs[0]
        first_image = [batch_images[0]]
        epoch = trainer.current_epoch

        for run_logger in trainer.loggers:
            if not isinstance(run_logger, WandbLogger):
                continue
            run_logger.log_image("Input Image", first_image, step=epoch)
            for title, boxes in (("Ground Truth", gt_boxes), ("Prediction", pred_boxes)):
                run_logger.log_image(title, first_image, step=epoch, boxes=[log_bbox(boxes)])
def setup_logger(logger_name, quite=False):
    """
    Route the named logger through a rich handler that prefixes every message with an emoji code.

    Existing handlers on the logger are removed first. The "faster_coco_eval.core.cocoeval"
    logger is always limited to ERROR level.

    Args:
        logger_name: Name passed to `logging.getLogger`.
        quite: When true, the named logger is limited to ERROR level as well.
    """

    class EmojiFormatter(logging.Formatter):
        """Formatter that puts an emoji code in front of the formatted record."""

        def format(self, record, emoji=":high_voltage:"):
            formatted = super().format(record)
            return f"{emoji} {formatted}"

    handler = RichHandler(markup=True)
    handler.setFormatter(EmojiFormatter("%(message)s"))

    target = logging.getLogger(logger_name)
    if target:
        target.handlers.clear()
        target.addHandler(handler)
        if quite:
            target.setLevel(logging.ERROR)

    logging.getLogger("faster_coco_eval.core.cocoeval").setLevel(logging.ERROR)
def setup(cfg: Config):
    """
    Configure logging and build the callbacks and loggers for a run.

    Args:
        cfg: Run configuration. Reads `name`, `task`, `use_tensorboard`, `use_wandb` and the
            optional `quite` flag.

    Returns:
        tuple: `(progress, loggers, save_path)` where `progress` is the list of callbacks,
        `loggers` the list of Lightning loggers and `save_path` the value returned by
        `validate_log_directory`.
    """
    # Use the flag's value, not its mere presence, so an explicit `quite=False` stays verbose.
    # This matches how `validate_log_directory` reads the same option.
    quite = bool(getattr(cfg, "quite", False))
    setup_logger("lightning.fabric", quite=quite)
    setup_logger("lightning.pytorch", quite=quite)

    def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
        """Forward wandb terminal output, line by line, to the project logger."""
        if silent:
            return
        for line in string.split("\n"):
            logger.info(Text.from_ansi(":globe_with_meridians: " + line))

    # Replace wandb's terminal logging hook so its messages go through the project logger.
    wandb.errors.term._log = custom_wandb_log

    save_path = validate_log_directory(cfg, cfg.name)

    progress, loggers = [], []

    if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
        progress.append(EMA(cfg.task.ema.decay))
    if quite:
        # Quiet mode: no progress bar, summary, image logging or experiment loggers.
        logger.setLevel(logging.ERROR)
        return progress, loggers, save_path

    progress.append(YOLORichProgressBar())
    progress.append(YOLORichModelSummary())
    progress.append(ImageLogger())
    if cfg.use_tensorboard:
        loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path))
    if cfg.use_wandb:
        loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None))

    return progress, loggers, save_path
def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
    """
    Print a rich table with one row per layer of the model.

    Each row shows a 1-based index, the layer's `layer_type`, its `tags`, its parameter count
    and its input/output channels. A `YOLO` instance is unwrapped to its `model` attribute first.

    Args:
        model: The model, or an iterable of its layers.
    """
    layers = model.model if isinstance(model, YOLO) else model

    layer_table = Table(title="Model Layers")
    column_specs = (
        ("Index", "center"),
        ("Layer Type", "center"),
        ("Tags", "center"),
        ("Params", "right"),
        ("Channels (IN->OUT)", "center"),
    )
    for heading, alignment in column_specs:
        layer_table.add_column(heading, justify=alignment)

    def describe(channel_spec):
        # List-like channel specs are abbreviated to "M".
        return "M" if isinstance(channel_spec, (list, ListConfig)) else str(channel_spec)

    for position, layer in enumerate(layers, start=1):
        param_count = sum(weight.numel() for weight in layer.parameters())
        in_c = getattr(layer, "in_c", None)
        out_c = getattr(layer, "out_c", None)
        if in_c and out_c:
            channels = f"{describe(in_c): >4} -> {describe(out_c): >4}"
        else:
            channels = "-"
        layer_table.add_row(str(position), layer.layer_type, layer.tags, f"{param_count:,}", channels)

    Console().print(layer_table)
@rank_zero_only
def validate_log_directory(cfg: Config, exp_name: str) -> Path:
    """
    Resolve and create the output directory for an experiment and attach a file log to it.

    The directory is `<cfg.out_path>/<cfg.task.task>/<exp_name>`. When `cfg.exist_ok` is false
    and that directory already exists, a numeric suffix (1, 2, ...) is appended to the name
    until an unused directory is found.

    Args:
        cfg: Run configuration. Reads `out_path`, `task.task`, `exist_ok` and the optional
            `quite` flag.
        exp_name: Requested experiment name.

    Returns:
        Path: The directory that was created. "output.log" inside it receives the project
        logger's output.

    NOTE(review): the function is wrapped in `rank_zero_only`; presumably it yields no path
    on other ranks — confirm that callers handle this.
    """
    base_path = Path(cfg.out_path, cfg.task.task)
    save_path = base_path / exp_name

    if not cfg.exist_ok:
        index = 1
        old_exp_name = exp_name
        while save_path.is_dir():
            exp_name = f"{old_exp_name}{index}"
            save_path = base_path / exp_name
            index += 1
        if index > 1:
            # `logger` is used as a standard `logging.Logger` elsewhere in this module
            # (`setLevel`, `addHandler`), which has no loguru-style `opt()`. Emit a plain
            # warning with the same rich markup style as the message below.
            logger.warning(
                f"π Experiment directory exists! Changed [red]{old_exp_name}[/] to [green]{exp_name}[/]"
            )

    save_path.mkdir(parents=True, exist_ok=True)
    if not getattr(cfg, "quite", False):
        logger.info(f"π Created log folder: [blue b u]{save_path}[/]")
    logger.addHandler(FileHandler(save_path / "output.log"))
    return save_path
def log_bbox(
    bboxes: Tensor, class_list: Optional[List[str]] = None, image_size: Tuple[int, int] = (640, 640)
) -> Dict[str, Dict[str, List[dict]]]:
    """
    Convert a bounding-box tensor into the nested dictionary passed as `boxes` to image logging.

    Args:
        bboxes (Tensor): Bounding boxes with shape (N, 5) or (N, 6), where each box is
            [class_id, x_min, y_min, x_max, y_max, (confidence)]. Rows are read in order and
            processing stops at the first row whose class_id is -1.
        class_list (Optional[List[str]]): List of class names. Defaults to None.
        image_size (Tuple[int, int]): Divisors for the x and y coordinates respectively.
            Defaults to (640, 640).

    Returns:
        Dict[str, Dict[str, List[dict]]]: `{"predictions": {"box_data": [...]}}`, where each
        entry holds the normalized "position", the integer "class_id", a "box_caption" when
        `class_list` is given and "scores" when the input has a confidence column.
    """
    box_data: List[dict] = []
    if bboxes.numel() == 0:
        # Nothing to convert; this also covers empty tensors that are not 2-D.
        return {"predictions": {"box_data": box_data}}

    # class_id is left untouched (divided by 1); x is divided by image_size[0], y by image_size[1].
    scale = torch.tensor([1, *image_size, *image_size], dtype=torch.float32, device=bboxes.device)
    coords = (bboxes[:, :5] / scale).tolist()
    # Confidence is read from the unscaled input so that it is not dropped by the
    # five-column slice used for normalization.
    if bboxes.shape[1] > 5:
        confidences = bboxes[:, 5].tolist()
    else:
        confidences = [None] * len(coords)

    for (class_id, x_min, y_min, x_max, y_max), confidence in zip(coords, confidences):
        if class_id == -1:
            break
        bbox_entry = {
            "position": {"minX": x_min, "maxX": x_max, "minY": y_min, "maxY": y_max},
            "class_id": int(class_id),
        }
        if class_list:
            bbox_entry["box_caption"] = class_list[int(class_id)]
        if confidence is not None:
            bbox_entry["scores"] = {"confidence": confidence}
        box_data.append(bbox_entry)

    return {"predictions": {"box_data": box_data}}
|