"""
Module for initializing the logging tools used in machine learning and data processing.
Integrates Loguru for console and file logging, Weights & Biases (wandb) for
experiment tracking, and Rich for progress bars and summary tables.

This setup keeps logging consistent across platforms, making monitoring and
debugging more effective.

Example:
    from tools.logger import custom_logger
    custom_logger()
"""

import os
import random
import sys
from collections import deque
from pathlib import Path
from typing import Any, Dict, Union

import numpy as np
import torch
import wandb
import wandb.errors.term
from loguru import logger
from omegaconf import ListConfig
from rich.console import Console, Group
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeRemainingColumn,
)
from rich.table import Table
from torch import Tensor
from torch.nn import ModuleList
from torch.optim import Optimizer

from yolo.config.config import Config, YOLOLayer
from yolo.model.yolo import YOLO
from yolo.utils.solver_utils import make_ap_table


def custom_logger(quiet: bool = False):
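    """Replace Loguru's default sink with a colorized stderr sink.

    Args:
        quiet: If True, remove all sinks and log nothing.
    """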
    logger.remove()
    if quiet:
        return
    logger.add(
        sys.stderr,
        colorize=True,
        format="<fg #003385>[{time:MM/DD HH:mm:ss}]</> <level>{level: ^8}</level>| <level>{message}</level>",
    )


# TODO: should be moved to a more appropriate module
def set_seed(seed):
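    """Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic."""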
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class ProgressLogger(Progress):
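    """Rich-based progress display that tracks epochs and batches and can mirror metrics to wandb."""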
    def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        # Stay quiet on non-zero ranks so DDP workers do not duplicate console output.
        self.quiet_mode = bool(local_rank) or getattr(cfg, "quite", False)  # "quite" is the config's own key
        custom_logger(self.quiet_mode)
        self.save_path = validate_log_directory(cfg, exp_name=cfg.name)

        progress_bar = (
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=None),
            TextColumn("{task.completed:.0f}/{task.total:.0f}"),
            TimeRemainingColumn(),
        )
        self.ap_table = Table()
        # TODO: load maxlen from the config files
        self.ap_past_list = deque(maxlen=5)
        self.last_result = 0
        super().__init__(*args, *progress_bar, **kwargs)

        self.use_wandb = cfg.use_wandb
        if self.use_wandb:
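            # Monkey-patch wandb's terminal writer so its console messages are
            # rendered by Loguru instead of printed raw.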
            wandb.errors.term._log = custom_wandb_log
            self.wandb = wandb.init(
                project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
            )

    def get_renderable(self):
        renderable = Group(*self.get_renderables(), self.ap_table)
        return renderable

    def start_train(self, num_epochs: int):
        self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)

    def start_one_epoch(
        self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
    ):
        self.num_batches = num_batches
        self.task = task
        if hasattr(self, "task_epoch"):
            self.update(self.task_epoch, description="[cyan] Preparing Data")

        if self.use_wandb and optimizer is not None:
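            # The three names assume the optimizer was built with bias/norm/conv
            # param groups in that order; zip() truncates if the counts differ.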
            lr_values = [params["lr"] for params in optimizer.param_groups]
            lr_names = ["Learning Rate/bias", "Learning Rate/norm", "Learning Rate/conv"]
            for lr_name, lr_value in zip(lr_names, lr_values):
                self.wandb.log({lr_name: lr_value}, step=epoch_idx)
        self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)

    def one_batch(self, batch_info: Dict[str, Tensor] = None):
        batch_info = batch_info or {}
        epoch_descript = "[cyan]" + self.task + "[white] |"
        batch_descript = "|"
        if self.task == "Train":
            self.update(self.task_epoch, advance=1 / self.num_batches)
        for info_name, info_val in batch_info.items():
            epoch_descript += f"{info_name: ^9}|"
            batch_descript += f"   {info_val:2.2f}  |"
        self.update(self.batch_task, advance=1, description=f"[green]{self.task} [white]{batch_descript}")
        if hasattr(self, "task_epoch"):
            self.update(self.task_epoch, description=epoch_descript)

    def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
        if self.task == "Train":
            prefix = "Loss/"
        elif self.task == "Validate":
            prefix = "Metrics/"
        else:
            prefix = ""
        batch_info = {f"{prefix}{key}": value for key, value in (batch_info or {}).items()}
        if self.use_wandb:
            self.wandb.log(batch_info, step=epoch_idx)
        self.remove_task(self.batch_task)

    def start_pycocotools(self):
        self.batch_task = self.add_task("[green]Run pycocotools", total=1)

    def finish_pycocotools(self, result, epoch_idx=-1):
        ap_table, ap_main = make_ap_table(result, self.ap_past_list, self.last_result, epoch_idx)
        self.last_result = np.maximum(result, self.last_result)
        self.ap_past_list.append((epoch_idx, ap_main))
        self.ap_table = ap_table

        if self.use_wandb:
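            # ap_main layout comes from make_ap_table: index 2 is AP @ .5:.95, index 5 is AP @ .5.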
            self.wandb.log({"PyCOCO/AP @ .5:.95": ap_main[2], "PyCOCO/AP @ .5": ap_main[5]})
        self.update(self.batch_task, advance=1)
        self.refresh()
        self.remove_task(self.batch_task)

    def finish_train(self):
        self.remove_task(self.task_epoch)
        self.stop()
        if self.use_wandb:
            self.wandb.finish()


def custom_wandb_log(string="", level=None, newline=True, repeat=True, prefix=True, silent=False):
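    """Replacement for wandb.errors.term._log that routes wandb console output through Loguru."""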
    if silent:
        return
    for line in string.split("\n"):
        logger.opt(raw=not newline, colors=True).info("🌐 " + line)


def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
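    """Print a Rich table of each layer's type, tags, parameter count, and channel sizes."""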
    if isinstance(model, YOLO):
        model = model.model
    console = Console()
    table = Table(title="Model Layers")

    table.add_column("Index", justify="center")
    table.add_column("Layer Type", justify="center")
    table.add_column("Tags", justify="center")
    table.add_column("Params", justify="right")
    table.add_column("Channels (IN->OUT)", justify="center")

    for idx, layer in enumerate(model, start=1):
        layer_param = sum(x.numel() for x in layer.parameters())  # number of parameters
        in_channels, out_channels = getattr(layer, "in_c", None), getattr(layer, "out_c", None)
        if in_channels is not None and out_channels is not None:
            if isinstance(in_channels, (list, ListConfig)):
                in_channels = "M"
            if isinstance(out_channels, (list, ListConfig)):
                out_channels = "M"
            channels = f"{str(in_channels): >4} -> {str(out_channels): >4}"
        else:
            channels = "-"
        table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
    console.print(table)


def validate_log_directory(cfg: Config, exp_name: str) -> Path:
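    """Resolve the experiment output directory, de-duplicating its name with a numeric
    suffix unless cfg.exist_ok is set, then create it and attach a Loguru file sink.
    """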
    base_path = Path(cfg.out_path, cfg.task.task)
    save_path = base_path / exp_name

    if not cfg.exist_ok:
        index = 1
        old_exp_name = exp_name
        while save_path.is_dir():
            exp_name = f"{old_exp_name}{index}"
            save_path = base_path / exp_name
            index += 1
        if index > 1:
            logger.opt(colors=True).warning(
                f"πŸ”€ Experiment directory exists! Changed <red>{old_exp_name}</> to <green>{exp_name}</>"
            )

    save_path.mkdir(parents=True, exist_ok=True)
    logger.opt(colors=True).info(f"πŸ“„ Created log folder: <u><fg #808080>{save_path}</></>")
    logger.add(save_path / "output.log", mode="w", backtrace=True, diagnose=True)
    return save_path
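

if __name__ == "__main__":
    # Minimal usage sketch: runs the config-free helpers; ProgressLogger needs a
    # full Config object, so it is not constructed here.
    custom_logger()
    set_seed(0)
    logger.info("custom_logger and set_seed initialized successfully")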