henry000 committed on
Commit
6e46676
·
1 Parent(s): 3fa2be7

💬 [Add] Progress class, handle progress bar

Browse files
Files changed (2) hide show
  1. yolo/tools/log_helper.py +29 -0
  2. yolo/tools/trainer.py +30 -20
yolo/tools/log_helper.py CHANGED
@@ -16,6 +16,7 @@ from typing import List
16
 
17
  from loguru import logger
18
  from rich.console import Console
 
19
  from rich.table import Table
20
 
21
  from yolo.config.config import YOLOLayer
@@ -29,6 +30,34 @@ def custom_logger():
29
  )
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def log_model(model: List[YOLOLayer]):
33
  console = Console()
34
  table = Table(title="Model Layers")
 
16
 
17
  from loguru import logger
18
  from rich.console import Console
19
+ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
20
  from rich.table import Table
21
 
22
  from yolo.config.config import YOLOLayer
 
30
  )
31
 
32
 
33
class CustomProgress:
    """Two-level rich progress display: one bar for epochs, one for batches."""

    def __init__(self):
        # Columns left to right: task description | bar | completed/total | ETA.
        self.progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=None),
            TextColumn("{task.completed}/{task.total}"),
            TimeRemainingColumn(),
        )

    def start_train(self, num_epochs: int):
        """Register the epoch-level task spanning `num_epochs` epochs."""
        self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)

    def one_epoch(self):
        """Advance the epoch bar by one completed epoch."""
        self.progress.update(self.task_epoch, advance=1)

    def start_batch(self, num_batches):
        """Register a fresh batch-level task spanning `num_batches` batches."""
        self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)

    def one_batch(self, loss_each):
        """Advance the batch bar and render the three loss components.

        `loss_each` is expected to unpack into exactly (iou, dfl, cls).
        """
        loss_iou, loss_dfl, loss_cls = loss_each
        # TODO: make it flexible? if need add more loss
        loss_str = f"Loss IoU: {loss_iou:.3f}, DFL: {loss_dfl:.3f}, CLS: {loss_cls:.3f}"
        self.progress.update(self.batch_task, advance=1, description=f"[green]Batches {loss_str}")

    def finish_batch(self):
        """Drop the batch bar once an epoch's batches are exhausted."""
        self.progress.remove_task(self.batch_task)
59
+
60
+
61
  def log_model(model: List[YOLOLayer]):
62
  console = Console()
63
  table = Table(title="Model Layers")
yolo/tools/trainer.py CHANGED
@@ -1,11 +1,13 @@
1
  import torch
2
  from loguru import logger
3
  from torch import Tensor
 
 
4
  from torch.cuda.amp import GradScaler, autocast
5
- from tqdm import tqdm
6
 
7
  from yolo.config.config import Config, TrainConfig
8
  from yolo.model.yolo import YOLO
 
9
  from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
10
  from yolo.utils.loss import get_loss_function
11
 
@@ -26,16 +28,13 @@ class Trainer:
26
  self.ema = None
27
  self.scaler = GradScaler()
28
 
29
- def train_one_batch(self, data: Tensor, targets: Tensor, progress: tqdm):
30
  data, targets = data.to(self.device), targets.to(self.device)
31
  self.optimizer.zero_grad()
32
 
33
  with autocast():
34
  outputs = self.model(data)
35
  loss, loss_item = self.loss_fn(outputs, targets)
36
- loss_iou, loss_dfl, loss_cls = loss_item
37
-
38
- progress.set_description(f"Loss IoU: {loss_iou:.5f}, DFL: {loss_dfl:.5f}, CLS: {loss_cls:.5f}")
39
 
40
  self.scaler.scale(loss).backward()
41
  self.scaler.step(self.optimizer)
@@ -43,17 +42,21 @@ class Trainer:
43
 
44
  return loss.item(), loss_item
45
 
46
- return loss.item()
47
-
48
- def train_one_epoch(self, dataloader):
49
  self.model.train()
50
  total_loss = 0
51
- with tqdm(dataloader, desc="Training") as progress:
52
- for data, targets in progress:
53
- loss = self.train_one_batch(data, targets, progress)
54
- total_loss += loss
55
- if self.scheduler:
56
- self.scheduler.step()
 
 
 
 
 
 
57
  return total_loss / len(dataloader)
58
 
59
  def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
@@ -69,9 +72,16 @@ class Trainer:
69
  torch.save(checkpoint, filename)
70
 
71
  def train(self, dataloader, num_epochs):
72
- logger.info("start train")
73
- for epoch in range(num_epochs):
74
- epoch_loss = self.train_one_epoch(dataloader)
75
- logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
76
- if (epoch + 1) % 5 == 0:
77
- self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
 
 
 
 
 
 
 
 
1
  import torch
2
  from loguru import logger
3
  from torch import Tensor
4
+
5
+ # TODO: We may can't use CUDA?
6
  from torch.cuda.amp import GradScaler, autocast
 
7
 
8
  from yolo.config.config import Config, TrainConfig
9
  from yolo.model.yolo import YOLO
10
+ from yolo.tools.log_helper import CustomProgress
11
  from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
12
  from yolo.utils.loss import get_loss_function
13
 
 
28
  self.ema = None
29
  self.scaler = GradScaler()
30
 
31
+ def train_one_batch(self, data: Tensor, targets: Tensor):
32
  data, targets = data.to(self.device), targets.to(self.device)
33
  self.optimizer.zero_grad()
34
 
35
  with autocast():
36
  outputs = self.model(data)
37
  loss, loss_item = self.loss_fn(outputs, targets)
 
 
 
38
 
39
  self.scaler.scale(loss).backward()
40
  self.scaler.step(self.optimizer)
 
42
 
43
  return loss.item(), loss_item
44
 
45
+ def train_one_epoch(self, dataloader, progress: CustomProgress):
 
 
46
  self.model.train()
47
  total_loss = 0
48
+ progress.start_batch(len(dataloader))
49
+
50
+ for data, targets in dataloader:
51
+ loss, loss_each = self.train_one_batch(data, targets)
52
+
53
+ total_loss += loss
54
+ progress.one_batch(loss_each)
55
+
56
+ if self.scheduler:
57
+ self.scheduler.step()
58
+
59
+ progress.finish_batch()
60
  return total_loss / len(dataloader)
61
 
62
  def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
 
72
  torch.save(checkpoint, filename)
73
 
74
  def train(self, dataloader, num_epochs):
75
+ logger.info("πŸš„ Start Training!")
76
+ progress = CustomProgress()
77
+
78
+ with progress.progress:
79
+ progress.start_train(num_epochs)
80
+ for epoch in range(num_epochs):
81
+
82
+ epoch_loss = self.train_one_epoch(dataloader, progress)
83
+ progress.one_epoch()
84
+
85
+ logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
86
+ if (epoch + 1) % 5 == 0:
87
+ self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")