File size: 5,361 Bytes
d5ba31a
a7ef999
 
d5ba31a
 
a7ef999
d5ba31a
 
 
 
 
 
 
16c6705
d5ba31a
f2370d7
31cab2b
f2370d7
7d976be
d5ba31a
0f9ffa2
6e46676
0f9ffa2
f2370d7
9c42a35
0f9ffa2
b5fa3f1
d5ba31a
 
7e84a57
d5ba31a
7e84a57
 
d5ba31a
 
16c6705
 
d5ba31a
0f9ffa2
 
d58a9b6
 
afa32b4
 
 
d58a9b6
 
6e46676
 
 
9c42a35
6e46676
 
2275731
 
d58a9b6
f2370d7
 
16c6705
afa32b4
16c6705
6e46676
 
6aabc6c
6e46676
9dfa178
9c42a35
16c6705
 
 
 
 
6e46676
 
3441a79
9dfa178
3441a79
9dfa178
f2370d7
 
 
 
6aabc6c
f2370d7
6aabc6c
f2370d7
 
9c42a35
6e46676
16c6705
6e46676
16c6705
 
 
6e46676
 
f2370d7
 
 
 
16c6705
f2370d7
 
dcceddd
0f9ffa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16c6705
 
b5fa3f1
 
16c6705
 
b5fa3f1
16c6705
 
 
 
 
 
 
 
 
 
 
 
 
057b810
16c6705
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Module for initializing logging tools used in machine learning and data processing.
Supports integration with Weights & Biases (wandb), Loguru, TensorBoard, and other
logging frameworks as needed.

This setup ensures consistent logging across various platforms, facilitating
effective monitoring and debugging.

Example:
    from tools.logger import custom_logger
    custom_logger()
"""

import os
import sys
from typing import Dict, List

import wandb
import wandb.errors.term
from loguru import logger
from rich.console import Console
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from rich.table import Table
from torch import Tensor
from torch.optim import Optimizer

from yolo.config.config import Config, YOLOLayer


def custom_logger(quite: bool = False):
    """Reset loguru's sinks and, unless silenced, attach a colorized stderr sink.

    Args:
        quite: When truthy, every existing sink is removed and none is re-added,
            so loguru produces no console output.
    """
    stderr_format = "<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>"
    logger.remove()
    if not quite:
        logger.add(sys.stderr, colorize=True, format=stderr_format)


class ProgressLogger:
    """Rich progress bars for training, optionally mirrored to Weights & Biases.

    On construction it configures loguru's console sink (removed entirely on a
    non-zero ``LOCAL_RANK`` or when ``cfg.quite`` is truthy), creates the
    experiment log directory, starts a rich ``Progress`` display and, when
    ``cfg.use_wandb`` is set, opens a wandb run.
    """

    def __init__(self, cfg: Config, exp_name: str):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        # Any non-zero rank is treated as quiet, so only rank 0 logs to the console.
        self.quite_mode = local_rank or getattr(cfg, "quite", False)
        custom_logger(self.quite_mode)
        # NOTE(review): the directory is named after cfg.name while the wandb run
        # below is named exp_name; confirm callers always pass the same value.
        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)

        self.progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=None),
            TextColumn("{task.completed:.0f}/{task.total:.0f}"),
            TimeRemainingColumn(),
        )
        self.progress.start()

        self.use_wandb = cfg.use_wandb
        self.wandb = None  # holds the wandb run only when use_wandb is enabled
        if self.use_wandb:
            # Route wandb's own terminal messages through loguru.
            wandb.errors.term._log = custom_wandb_log
            self.wandb = wandb.init(
                project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
            )

    def start_train(self, num_epochs: int):
        """Add the epoch-level progress bar (advanced fractionally by ``one_batch``)."""
        self.task_epoch = self.progress.add_task("[cyan]Epochs  [white]| Loss | Box  | DFL  | BCE  |", total=num_epochs)

    def start_one_epoch(self, num_batches: int, optimizer: Optimizer = None, epoch_idx: int = None):
        """Add a batch-level progress bar and, with wandb, log current learning rates.

        Args:
            num_batches: Number of batches in this epoch (total of the batch bar).
            optimizer: Optimizer whose ``param_groups`` learning rates are logged.
                May be None, in which case no learning rate is logged.
            epoch_idx: wandb step at which the learning rates are recorded.
        """
        self.num_batches = num_batches
        if self.use_wandb and optimizer is not None:
            # NOTE(review): assumes param groups are ordered bias, norm, conv;
            # confirm against the optimizer factory. zip() drops any extra groups.
            lr_names = ["bias", "norm", "conv"]
            for lr_name, param_group in zip(lr_names, optimizer.param_groups):
                self.wandb.log({f"Learning Rate/{lr_name}": param_group["lr"]}, step=epoch_idx)
        self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)

    def one_batch(self, loss_dict: Dict[str, Tensor] = None):
        """Advance the batch bar by one step.

        Args:
            loss_dict: Mapping of loss name to scalar loss for a training batch.
                When None, the step is treated as validation: only the batch bar
                advances and its description becomes "Validating".
        """
        if loss_dict is None:
            self.progress.update(self.batch_task, advance=1, description="[green]Validating")
            return
        if self.use_wandb:
            # One call per batch so all losses of this batch share the same wandb step.
            # NOTE(review): learning rates are logged with an explicit step=epoch_idx
            # while these use wandb's auto-incrementing step; wandb ignores logs
            # with a step lower than its current one, so confirm the LR curves appear.
            self.wandb.log({f"Loss/{loss_name}": loss_value for loss_name, loss_value in loss_dict.items()})

        loss_str = "| -.-- |" + "".join(f" {loss_val:2.2f} |" for loss_val in loss_dict.values())

        self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
        # Each batch contributes 1/num_batches of an epoch to the epoch bar.
        self.progress.update(self.task_epoch, advance=1 / self.num_batches)

    def finish_one_epoch(self):
        """Remove the batch-level progress bar of the finished epoch."""
        self.progress.remove_task(self.batch_task)

    def finish_train(self):
        """Close the wandb run, if one was opened."""
        if self.use_wandb:
            self.wandb.finish()


def custom_wandb_log(string="", level=None, newline=True, repeat=True, prefix=True, silent=False):
    """Replacement for wandb's terminal logger that forwards messages to loguru.

    Each line of ``string`` is logged separately with a globe prefix.

    Args:
        string: Message text; may contain several newline-separated lines.
        level: Standard ``logging`` numeric level (e.g. ``logging.WARNING``).
            Known levels map to the loguru level of the same name; anything else
            (including the default None) is logged as INFO.
        newline: When False, lines are emitted raw (no loguru format, no newline).
        repeat: Accepted for signature compatibility with wandb; ignored.
        prefix: Accepted for signature compatibility with wandb; ignored.
        silent: When True, nothing is logged.
    """
    if silent:
        return
    import logging

    level_name = logging.getLevelName(level) if isinstance(level, int) else "INFO"
    # stdlib level names that loguru also defines; other values fall back to INFO.
    if level_name not in {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}:
        level_name = "INFO"
    for line in string.split("\n"):
        logger.opt(raw=not newline, colors=True).log(level_name, "🌐 " + line)


def log_model_structure(model: List[YOLOLayer]):
    """Print a rich table summarizing each layer of ``model``.

    One row per layer: 1-based index, ``layer_type``, ``tags``, parameter count
    with thousands separators, and ``in_c -> out_c`` when both attributes are
    present and truthy (otherwise "-").
    """
    table = Table(title="Model Layers")
    column_specs = (
        ("Index", "center"),
        ("Layer Type", "center"),
        ("Tags", "center"),
        ("Params", "right"),
        ("Channels (IN->OUT)", "center"),
    )
    for header, alignment in column_specs:
        table.add_column(header, justify=alignment)

    for position, layer in enumerate(model, start=1):
        n_params = sum(param.numel() for param in layer.parameters())
        c_in = getattr(layer, "in_c", None)
        c_out = getattr(layer, "out_c", None)
        channels = f"{c_in:4} -> {c_out:4}" if c_in and c_out else "-"
        table.add_row(str(position), layer.layer_type, layer.tags, f"{n_params:,}", channels)

    Console().print(table)


def validate_log_directory(cfg: Config, exp_name: str):
    """Create the experiment output directory and attach a file sink to loguru.

    The directory is ``<cfg.out_path>/<cfg.task.task>/<exp_name>``. When
    ``cfg.exist_ok`` is false and that directory already exists, a numeric
    suffix (1, 2, ...) is appended to ``exp_name`` until an unused path is found.

    Args:
        cfg: Configuration providing ``out_path``, ``task.task`` and ``exist_ok``.
        exp_name: Base name of the experiment directory.

    Returns:
        Path of the (now existing) experiment directory. As a side effect, a
        loguru sink writing to ``output.log`` inside it is added.
    """
    base_path = os.path.join(cfg.out_path, cfg.task.task)
    save_path = os.path.join(base_path, exp_name)

    if not cfg.exist_ok:
        index = 1
        old_exp_name = exp_name
        while os.path.isdir(save_path):
            exp_name = f"{old_exp_name}{index}"
            save_path = os.path.join(base_path, exp_name)
            index += 1
        if index > 1:
            logger.opt(colors=True).warning(
                f"🔀 Experiment directory exists! Changed <red>{old_exp_name}</> to <green>{exp_name}</>"
            )

    os.makedirs(save_path, exist_ok=True)
    logger.opt(colors=True).info(f"📄 Created log folder: <u><fg #808080>{save_path}</></>")
    # mode="w" truncates any output.log already present in this directory.
    logger.add(os.path.join(save_path, "output.log"), mode="w", backtrace=True, diagnose=True)
    return save_path