✨ [New] wandb, progress class for handle proccess
Browse files- requirements.txt +3 -1
- yolo/tools/log_helper.py +31 -9
- yolo/tools/trainer.py +9 -9
- yolo/utils/loss.py +9 -9
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
einops
|
|
|
2 |
hydra-core
|
3 |
loguru
|
4 |
numpy
|
@@ -9,4 +10,5 @@ requests
|
|
9 |
rich
|
10 |
torch
|
11 |
torchvision
|
12 |
-
tqdm
|
|
|
|
1 |
einops
|
2 |
+
graphviz
|
3 |
hydra-core
|
4 |
loguru
|
5 |
numpy
|
|
|
10 |
rich
|
11 |
torch
|
12 |
torchvision
|
13 |
+
tqdm
|
14 |
+
wandb
|
yolo/tools/log_helper.py
CHANGED
@@ -12,32 +12,39 @@ Example:
|
|
12 |
"""
|
13 |
|
14 |
import sys
|
15 |
-
from typing import List
|
16 |
|
|
|
|
|
17 |
from loguru import logger
|
18 |
from rich.console import Console
|
19 |
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
20 |
from rich.table import Table
|
|
|
21 |
|
22 |
-
from yolo.config.config import YOLOLayer
|
23 |
|
24 |
|
25 |
def custom_logger():
|
26 |
logger.remove()
|
27 |
logger.add(
|
28 |
sys.stderr,
|
29 |
-
format="<
|
30 |
)
|
31 |
|
32 |
|
33 |
class CustomProgress:
|
34 |
-
def __init__(self):
|
35 |
self.progress = Progress(
|
36 |
TextColumn("[progress.description]{task.description}"),
|
37 |
BarColumn(bar_width=None),
|
38 |
TextColumn("{task.completed}/{task.total}"),
|
39 |
TimeRemainingColumn(),
|
40 |
)
|
|
|
|
|
|
|
|
|
41 |
|
42 |
def start_train(self, num_epochs: int):
|
43 |
self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
|
@@ -45,19 +52,34 @@ class CustomProgress:
|
|
45 |
def one_epoch(self):
|
46 |
self.progress.update(self.task_epoch, advance=1)
|
47 |
|
|
|
|
|
|
|
48 |
def start_batch(self, num_batches):
|
49 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
50 |
|
51 |
-
def one_batch(self,
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
def finish_batch(self):
|
58 |
self.progress.remove_task(self.batch_task)
|
59 |
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def log_model(model: List[YOLOLayer]):
|
62 |
console = Console()
|
63 |
table = Table(title="Model Layers")
|
|
|
12 |
"""
|
13 |
|
14 |
import sys
|
15 |
+
from typing import Dict, List
|
16 |
|
17 |
+
import wandb
|
18 |
+
import wandb.errors
|
19 |
from loguru import logger
|
20 |
from rich.console import Console
|
21 |
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
22 |
from rich.table import Table
|
23 |
+
from torch import Tensor
|
24 |
|
25 |
+
from yolo.config.config import Config, YOLOLayer
|
26 |
|
27 |
|
28 |
def custom_logger():
|
29 |
logger.remove()
|
30 |
logger.add(
|
31 |
sys.stderr,
|
32 |
+
format="<fg #003385>[{time:MM/DD HH:mm:ss}]</fg #003385><level>{level: ^8}</level>| <level>{message}</level>",
|
33 |
)
|
34 |
|
35 |
|
36 |
class CustomProgress:
|
37 |
+
def __init__(self, cfg: Config, use_wandb: bool = False):
|
38 |
self.progress = Progress(
|
39 |
TextColumn("[progress.description]{task.description}"),
|
40 |
BarColumn(bar_width=None),
|
41 |
TextColumn("{task.completed}/{task.total}"),
|
42 |
TimeRemainingColumn(),
|
43 |
)
|
44 |
+
self.use_wandb = use_wandb
|
45 |
+
if self.use_wandb:
|
46 |
+
wandb.errors.term._log = custom_wandb_log
|
47 |
+
self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
|
48 |
|
49 |
def start_train(self, num_epochs: int):
|
50 |
self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
|
|
|
52 |
def one_epoch(self):
|
53 |
self.progress.update(self.task_epoch, advance=1)
|
54 |
|
55 |
+
def finish_epoch(self):
|
56 |
+
self.wandb.finish()
|
57 |
+
|
58 |
def start_batch(self, num_batches):
|
59 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
60 |
|
61 |
+
def one_batch(self, loss_dict: Dict[str, Tensor]):
|
62 |
+
if self.use_wandb:
|
63 |
+
for loss_name, loss_value in loss_dict.items():
|
64 |
+
self.wandb.log({f"Loss/{loss_name}": loss_value})
|
65 |
+
|
66 |
+
loss_str = "Loss"
|
67 |
+
for loss_name, loss_val in loss_dict.items():
|
68 |
+
loss_str += f" {loss_name[:-4]}: {loss_val:.2f} |"
|
69 |
+
|
70 |
+
self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
|
71 |
|
72 |
def finish_batch(self):
|
73 |
self.progress.remove_task(self.batch_task)
|
74 |
|
75 |
|
76 |
+
def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
|
77 |
+
if silent:
|
78 |
+
return
|
79 |
+
for line in string.split("\n"):
|
80 |
+
logger.opt(raw=not newline).info("🌐 " + line)
|
81 |
+
|
82 |
+
|
83 |
def log_model(model: List[YOLOLayer]):
|
84 |
console = Console()
|
85 |
table = Table(title="Model Layers")
|
yolo/tools/trainer.py
CHANGED
@@ -21,6 +21,7 @@ class Trainer:
|
|
21 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
22 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
23 |
self.loss_fn = get_loss_function(cfg)
|
|
|
24 |
|
25 |
if getattr(train_cfg.ema, "enabled", False):
|
26 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
@@ -42,21 +43,21 @@ class Trainer:
|
|
42 |
|
43 |
return loss.item(), loss_item
|
44 |
|
45 |
-
def train_one_epoch(self, dataloader
|
46 |
self.model.train()
|
47 |
total_loss = 0
|
48 |
-
progress.start_batch(len(dataloader))
|
49 |
|
50 |
for data, targets in dataloader:
|
51 |
loss, loss_each = self.train_one_batch(data, targets)
|
52 |
|
53 |
total_loss += loss
|
54 |
-
progress.one_batch(loss_each)
|
55 |
|
56 |
if self.scheduler:
|
57 |
self.scheduler.step()
|
58 |
|
59 |
-
progress.finish_batch()
|
60 |
return total_loss / len(dataloader)
|
61 |
|
62 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
@@ -73,14 +74,13 @@ class Trainer:
|
|
73 |
|
74 |
def train(self, dataloader, num_epochs):
|
75 |
logger.info("🚄 Start Training!")
|
76 |
-
progress = CustomProgress()
|
77 |
|
78 |
-
with progress.progress:
|
79 |
-
progress.start_train(num_epochs)
|
80 |
for epoch in range(num_epochs):
|
81 |
|
82 |
-
epoch_loss = self.train_one_epoch(dataloader, progress)
|
83 |
-
progress.one_epoch()
|
84 |
|
85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
86 |
if (epoch + 1) % 5 == 0:
|
|
|
21 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
22 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
23 |
self.loss_fn = get_loss_function(cfg)
|
24 |
+
self.progress = CustomProgress(cfg, use_wandb=True)
|
25 |
|
26 |
if getattr(train_cfg.ema, "enabled", False):
|
27 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
|
|
43 |
|
44 |
return loss.item(), loss_item
|
45 |
|
46 |
+
def train_one_epoch(self, dataloader):
|
47 |
self.model.train()
|
48 |
total_loss = 0
|
49 |
+
self.progress.start_batch(len(dataloader))
|
50 |
|
51 |
for data, targets in dataloader:
|
52 |
loss, loss_each = self.train_one_batch(data, targets)
|
53 |
|
54 |
total_loss += loss
|
55 |
+
self.progress.one_batch(loss_each)
|
56 |
|
57 |
if self.scheduler:
|
58 |
self.scheduler.step()
|
59 |
|
60 |
+
self.progress.finish_batch()
|
61 |
return total_loss / len(dataloader)
|
62 |
|
63 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
|
|
74 |
|
75 |
def train(self, dataloader, num_epochs):
|
76 |
logger.info("🚄 Start Training!")
|
|
|
77 |
|
78 |
+
with self.progress.progress:
|
79 |
+
self.progress.start_train(num_epochs)
|
80 |
for epoch in range(num_epochs):
|
81 |
|
82 |
+
epoch_loss = self.train_one_epoch(dataloader, self.progress)
|
83 |
+
self.progress.one_epoch()
|
84 |
|
85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
86 |
if (epoch + 1) % 5 == 0:
|
yolo/utils/loss.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
-
import
|
2 |
-
from typing import Any, List, Tuple
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
@@ -169,7 +168,7 @@ class DualLoss:
|
|
169 |
self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
|
170 |
self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
|
171 |
|
172 |
-
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor,
|
173 |
targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
|
174 |
|
175 |
# TODO: Need Refactor this region, make it flexible!
|
@@ -177,12 +176,13 @@ class DualLoss:
|
|
177 |
aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
|
178 |
main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
|
179 |
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
186 |
|
187 |
|
188 |
def get_loss_function(cfg: Config) -> YOLOLoss:
|
|
|
1 |
+
from typing import Any, Dict, List, Tuple
|
|
|
2 |
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
|
|
168 |
self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
|
169 |
self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
|
170 |
|
171 |
+
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
|
172 |
targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
|
173 |
|
174 |
# TODO: Need Refactor this region, make it flexible!
|
|
|
176 |
aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
|
177 |
main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
|
178 |
|
179 |
+
loss_dict = {
|
180 |
+
"BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
|
181 |
+
"DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
|
182 |
+
"BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
|
183 |
+
}
|
184 |
+
loss_sum = sum(list(loss_dict.values())) / len(loss_dict)
|
185 |
+
return loss_sum, loss_dict
|
186 |
|
187 |
|
188 |
def get_loss_function(cfg: Config) -> YOLOLoss:
|