henry000 committed on
Commit
12dfccf
·
2 Parent(s): 1a069e1 5bf55cf

🔀 [Merge] branch 'TRAIN' into TEST

Browse files
examples/example_train.py CHANGED
@@ -28,7 +28,7 @@ def main(cfg: Config):
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- trainer = Trainer(model, cfg.hyper.train, device)
32
  trainer.train(dataloader, 10)
33
 
34
 
 
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
+ trainer = Trainer(model, cfg, device)
32
  trainer.train(dataloader, 10)
33
 
34
 
yolo/config/hyper/default.yaml CHANGED
@@ -1,5 +1,5 @@
1
  data:
2
- batch_size: 4
3
  shuffle: True
4
  num_workers: 4
5
  pin_memory: True
 
1
  data:
2
+ batch_size: 8
3
  shuffle: True
4
  num_workers: 4
5
  pin_memory: True
yolo/tools/trainer.py CHANGED
@@ -1,48 +1,63 @@
1
  import torch
2
  from loguru import logger
 
 
3
  from tqdm import tqdm
4
 
5
- from yolo.config.config import TrainConfig
6
  from yolo.model.yolo import YOLO
7
  from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
8
  from yolo.utils.loss import get_loss_function
9
 
10
 
11
  class Trainer:
12
- def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
 
 
13
  self.model = model.to(device)
14
  self.device = device
15
  self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
16
  self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
17
- self.loss_fn = get_loss_function()
18
 
19
  if train_cfg.ema.get("enabled", False):
20
  self.ema = EMA(model, decay=train_cfg.ema.decay)
21
  else:
22
  self.ema = None
 
23
 
24
- def train_one_batch(self, data, targets):
25
  data, targets = data.to(self.device), targets.to(self.device)
26
  self.optimizer.zero_grad()
27
- outputs = self.model(data)
28
- loss = self.loss_fn(outputs, targets)
29
- loss.backward()
30
- self.optimizer.step()
 
 
 
 
 
 
 
 
31
  if self.ema:
32
  self.ema.update()
 
33
  return loss.item()
34
 
35
  def train_one_epoch(self, dataloader):
36
  self.model.train()
37
  total_loss = 0
38
- for data, targets in tqdm(dataloader, desc="Training"):
39
- loss = self.train_one_batch(data, targets)
40
- total_loss += loss
41
- if self.scheduler:
42
- self.scheduler.step()
 
43
  return total_loss / len(dataloader)
44
 
45
- def save_checkpoint(self, epoch, filename="checkpoint.pt"):
46
  checkpoint = {
47
  "epoch": epoch,
48
  "model_state_dict": self.model.state_dict(),
 
1
  import torch
2
  from loguru import logger
3
+ from torch import Tensor
4
+ from torch.cuda.amp import GradScaler, autocast
5
  from tqdm import tqdm
6
 
7
+ from yolo.config.config import Config, TrainConfig
8
  from yolo.model.yolo import YOLO
9
  from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
10
  from yolo.utils.loss import get_loss_function
11
 
12
 
13
  class Trainer:
14
+ def __init__(self, model: YOLO, cfg: Config, device):
15
+ train_cfg: TrainConfig = cfg.hyper.train
16
+
17
  self.model = model.to(device)
18
  self.device = device
19
  self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
20
  self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
21
+ self.loss_fn = get_loss_function(cfg)
22
 
23
  if train_cfg.ema.get("enabled", False):
24
  self.ema = EMA(model, decay=train_cfg.ema.decay)
25
  else:
26
  self.ema = None
27
+ self.scaler = GradScaler()
28
 
29
+ def train_one_batch(self, data: Tensor, targets: Tensor, progress: tqdm):
30
  data, targets = data.to(self.device), targets.to(self.device)
31
  self.optimizer.zero_grad()
32
+
33
+ with autocast():
34
+ outputs = self.model(data)
35
+ loss, loss_item = self.loss_fn(outputs, targets)
36
+ loss_iou, loss_dfl, loss_cls = loss_item
37
+
38
+ progress.set_description(f"Loss IoU: {loss_iou:.5f}, DFL: {loss_dfl:.5f}, CLS: {loss_cls:.5f}")
39
+
40
+ self.scaler.scale(loss).backward()
41
+ self.scaler.step(self.optimizer)
42
+ self.scaler.update()
43
+
44
  if self.ema:
45
  self.ema.update()
46
+
47
  return loss.item()
48
 
49
  def train_one_epoch(self, dataloader):
50
  self.model.train()
51
  total_loss = 0
52
+ with tqdm(dataloader, desc="Training") as progress:
53
+ for data, targets in progress:
54
+ loss = self.train_one_batch(data, targets, progress)
55
+ total_loss += loss
56
+ if self.scheduler:
57
+ self.scheduler.step()
58
  return total_loss / len(dataloader)
59
 
60
+ def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
61
  checkpoint = {
62
  "epoch": epoch,
63
  "model_state_dict": self.model.state_dict(),
yolo/utils/loss.py CHANGED
@@ -17,10 +17,6 @@ from yolo.tools.bbox_helper import (
17
  )
18
 
19
 
20
- def get_loss_function(*args, **kwargs):
21
- raise NotImplementedError
22
-
23
-
24
  class BCELoss(nn.Module):
25
  def __init__(self) -> None:
26
  super().__init__()
@@ -144,7 +140,9 @@ class YOLOLoss:
144
  # Batch_Size x (Anchor + Class) x H x W
145
  # TODO: check datatype, why targets has a little bit error with origin version
146
  predicts, predicts_anc = self.parse_predicts(predicts[0])
147
- targets = self.parse_targets(targets, batch_size=predicts.size(0))
 
 
148
 
149
  align_targets, valid_masks = self.matcher(targets, predicts)
150
  # calculate loss between with instance and predict
@@ -162,5 +160,11 @@ class YOLOLoss:
162
  ## -- DFL -- ##
163
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
164
 
165
- logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
166
- return loss_iou, loss_dfl, loss_cls
 
 
 
 
 
 
 
17
  )
18
 
19
 
 
 
 
 
20
  class BCELoss(nn.Module):
21
  def __init__(self) -> None:
22
  super().__init__()
 
140
  # Batch_Size x (Anchor + Class) x H x W
141
  # TODO: check datatype, why targets has a little bit error with origin version
142
  predicts, predicts_anc = self.parse_predicts(predicts[0])
143
+ # TODO: Refactor this operator
144
+ # targets = self.parse_targets(targets, batch_size=predicts.size(0))
145
+ targets[:, :, 1:] = targets[:, :, 1:] * self.scale_up
146
 
147
  align_targets, valid_masks = self.matcher(targets, predicts)
148
  # calculate loss between with instance and predict
 
160
  ## -- DFL -- ##
161
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
162
 
163
+ loss_sum = loss_iou * 0.5 + loss_dfl * 1.5 + loss_cls * 0.5
164
+ return loss_sum, (loss_iou, loss_dfl, loss_cls)
165
+
166
+
167
+ def get_loss_function(cfg: Config) -> YOLOLoss:
168
+ loss_function = YOLOLoss(cfg)
169
+ logger.info("✅ Success load loss function")
170
+ return loss_function