henry000 committed
Commit 669657d · 1 Parent(s): f0fdf9a

🚀 [Add] torch auto mixed precision

Files changed (1):
  yolo/tools/trainer.py  +25 -12
yolo/tools/trainer.py CHANGED
@@ -1,8 +1,10 @@
 import torch
 from loguru import logger
+from torch import Tensor
+from torch.cuda.amp import GradScaler, autocast
 from tqdm import tqdm
 
-from yolo.config.config import TrainConfig
+from yolo.config.config import Config, TrainConfig
 from yolo.model.yolo import YOLO
 from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
 from yolo.utils.loss import get_loss_function
@@ -22,29 +24,40 @@ class Trainer:
             self.ema = EMA(model, decay=train_cfg.ema.decay)
         else:
             self.ema = None
+        self.scaler = GradScaler()
 
-    def train_one_batch(self, data, targets):
+    def train_one_batch(self, data: Tensor, targets: Tensor, progress: tqdm):
         data, targets = data.to(self.device), targets.to(self.device)
         self.optimizer.zero_grad()
-        outputs = self.model(data)
-        loss = self.loss_fn(outputs, targets)
-        loss.backward()
-        self.optimizer.step()
+
+        with autocast():
+            outputs = self.model(data)
+            loss, loss_item = self.loss_fn(outputs, targets)
+            loss_iou, loss_dfl, loss_cls = loss_item
+
+        progress.set_description(f"Loss IoU: {loss_iou:.5f}, DFL: {loss_dfl:.5f}, CLS: {loss_cls:.5f}")
+
+        self.scaler.scale(loss).backward()
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+
         if self.ema:
             self.ema.update()
+
         return loss.item()
 
     def train_one_epoch(self, dataloader):
         self.model.train()
         total_loss = 0
-        for data, targets in tqdm(dataloader, desc="Training"):
-            loss = self.train_one_batch(data, targets)
-            total_loss += loss
-        if self.scheduler:
-            self.scheduler.step()
+        with tqdm(dataloader, desc="Training") as progress:
+            for data, targets in progress:
+                loss = self.train_one_batch(data, targets, progress)
+                total_loss += loss
+        if self.scheduler:
+            self.scheduler.step()
         return total_loss / len(dataloader)
 
-    def save_checkpoint(self, epoch, filename="checkpoint.pt"):
+    def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
         checkpoint = {
             "epoch": epoch,
             "model_state_dict": self.model.state_dict(),