💬 [Add] Progress class, handle progress bar
- yolo/tools/log_helper.py +29 -0
- yolo/tools/trainer.py +30 -20
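
This commit swaps the tqdm-based progress output for a rich-based display: a new CustomProgress class in log_helper.py owns a single rich Progress instance with a persistent epoch task and a per-epoch batch task, and trainer.py is rewired so that train() creates and drives the progress display while train_one_batch() loses its progress parameter.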
yolo/tools/log_helper.py

@@ -16,6 +16,7 @@ from typing import List
 
 from loguru import logger
 from rich.console import Console
+from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
 from rich.table import Table
 
 from yolo.config.config import YOLOLayer
@@ -29,6 +30,34 @@ def custom_logger():
     )
 
 
+class CustomProgress:
+    def __init__(self):
+        self.progress = Progress(
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(bar_width=None),
+            TextColumn("{task.completed}/{task.total}"),
+            TimeRemainingColumn(),
+        )
+
+    def start_train(self, num_epochs: int):
+        self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
+
+    def one_epoch(self):
+        self.progress.update(self.task_epoch, advance=1)
+
+    def start_batch(self, num_batches):
+        self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
+
+    def one_batch(self, loss_each):
+        loss_iou, loss_dfl, loss_cls = loss_each
+        # TODO: make this flexible in case more loss terms are added
+        loss_str = f"Loss IoU: {loss_iou:.3f}, DFL: {loss_dfl:.3f}, CLS: {loss_cls:.3f}"
+        self.progress.update(self.batch_task, advance=1, description=f"[green]Batches {loss_str}")
+
+    def finish_batch(self):
+        self.progress.remove_task(self.batch_task)
+
+
 def log_model(model: List[YOLOLayer]):
     console = Console()
     table = Table(title="Model Layers")
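
CustomProgress keeps at most two live rows: the epoch bar persists for the whole run, while the batch bar is added at the start of each epoch and dropped by finish_batch() when the epoch ends. A minimal usage sketch, assuming yolo.tools.log_helper is importable; the loop bounds, sleep, and random losses are stand-ins for a real dataloader:

# Sketch of how the trainer is expected to drive CustomProgress.
import random
import time

from yolo.tools.log_helper import CustomProgress

progress = CustomProgress()
num_epochs, num_batches = 3, 20

with progress.progress:  # rich.Progress is a context manager: starts/stops live rendering
    progress.start_train(num_epochs)
    for _ in range(num_epochs):
        progress.start_batch(num_batches)
        for _ in range(num_batches):
            loss_each = (random.random(), random.random(), random.random())
            progress.one_batch(loss_each)  # advance batch bar, show per-term losses
            time.sleep(0.01)
        progress.finish_batch()  # remove the finished batch bar
        progress.one_epoch()     # advance the epoch bar

Removing and re-adding the batch task each epoch is what keeps the display to two rows instead of accumulating one bar per epoch.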
yolo/tools/trainer.py

@@ -1,11 +1,13 @@
 import torch
 from loguru import logger
 from torch import Tensor
+
+# TODO: handle machines where CUDA is not available
 from torch.cuda.amp import GradScaler, autocast
-from tqdm import tqdm
 
 from yolo.config.config import Config, TrainConfig
 from yolo.model.yolo import YOLO
+from yolo.tools.log_helper import CustomProgress
 from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
 from yolo.utils.loss import get_loss_function
 
@@ -26,16 +28,13 @@ class Trainer:
         self.ema = None
         self.scaler = GradScaler()
 
-    def train_one_batch(self, data: Tensor, targets: Tensor, progress):
+    def train_one_batch(self, data: Tensor, targets: Tensor):
         data, targets = data.to(self.device), targets.to(self.device)
         self.optimizer.zero_grad()
 
         with autocast():
             outputs = self.model(data)
             loss, loss_item = self.loss_fn(outputs, targets)
-            loss_iou, loss_dfl, loss_cls = loss_item
-
-            progress.set_description(f"Loss IoU: {loss_iou:.5f}, DFL: {loss_dfl:.5f}, CLS: {loss_cls:.5f}")
 
         self.scaler.scale(loss).backward()
         self.scaler.step(self.optimizer)
@@ -43,17 +42,21 @@ class Trainer:
 
         return loss.item(), loss_item
 
-
-
-    def train_one_epoch(self, dataloader):
+    def train_one_epoch(self, dataloader, progress: CustomProgress):
         self.model.train()
         total_loss = 0
-        progress = tqdm(dataloader, desc="Training")
-        for data, targets in progress:
-            loss, loss_item = self.train_one_batch(data, targets, progress)
-            total_loss += loss
-        if self.scheduler:
-            self.scheduler.step()
+        progress.start_batch(len(dataloader))
+
+        for data, targets in dataloader:
+            loss, loss_each = self.train_one_batch(data, targets)
+
+            total_loss += loss
+            progress.one_batch(loss_each)
+
+        if self.scheduler:
+            self.scheduler.step()
+
+        progress.finish_batch()
         return total_loss / len(dataloader)
 
     def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
@@ -69,9 +72,16 @@ class Trainer:
         torch.save(checkpoint, filename)
 
     def train(self, dataloader, num_epochs):
-        logger.info("Start training!")
-        for epoch in range(num_epochs):
-            epoch_loss = self.train_one_epoch(dataloader)
-            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
-            if (epoch + 1) % 5 == 0:
-                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
+        logger.info("🚀 Start Training!")
+        progress = CustomProgress()
+
+        with progress.progress:
+            progress.start_train(num_epochs)
+            for epoch in range(num_epochs):
+
+                epoch_loss = self.train_one_epoch(dataloader, progress)
+                progress.one_epoch()
+
+                logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
+                if (epoch + 1) % 5 == 0:
+                    self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")