import torch
from loguru import logger
from torch import Tensor
# TODO: We may not be able to use CUDA here?
from torch.cuda.amp import GradScaler, autocast
from yolo.config.config import Config, TrainConfig
from yolo.model.yolo import get_model
from yolo.tools.loss_functions import get_loss_function
from yolo.utils.logging_utils import ProgressTracker
from yolo.utils.model_utils import (
    ExponentialMovingAverage,
    create_optimizer,
    create_scheduler,
)


class ModelTrainer:
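    """Train a YOLO model with AMP: forward/backward passes, optimizer/scheduler steps, optional EMA, and checkpoints."""
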
    def __init__(self, cfg: Config, save_path: str, device):
        train_cfg: TrainConfig = cfg.hyper.train
        model = get_model(cfg)
        self.model = model.to(device)
        self.device = device
        self.optimizer = create_optimizer(model, train_cfg.optimizer)
        self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
        self.loss_fn = get_loss_function(cfg)
        self.progress = ProgressTracker(cfg, save_path, use_wandb=True)
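        # Keep an exponential moving average of the weights when enabled in the training config.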
        if getattr(train_cfg.ema, "enabled", False):
            self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
        else:
            self.ema = None
        self.scaler = GradScaler()

    def train_one_batch(self, data: Tensor, targets: Tensor):
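        """Run one AMP forward/backward pass and an optimizer step; return the scalar loss and its components."""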
        data, targets = data.to(self.device), targets.to(self.device)
        self.optimizer.zero_grad()
        with autocast():
            outputs = self.model(data)
            loss, loss_item = self.loss_fn(outputs, targets)
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
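        # TODO: if ExponentialMovingAverage exposes an update()/step() hook, it should likely
        # be called here so the shadow weights track the optimizer updates.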
        return loss.item(), loss_item

    def train_one_epoch(self, dataloader):
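        """Run one pass over the dataloader, step the scheduler once, and return the mean batch loss."""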
        self.model.train()
        total_loss = 0
        for data, targets in dataloader:
            loss, loss_each = self.train_one_batch(data, targets)
            total_loss += loss
            self.progress.one_batch(loss_each)
        if self.scheduler:
            self.scheduler.step()
        return total_loss / len(dataloader)

    def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
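        """Save epoch, model, and optimizer state; also store the EMA weights when EMA is enabled."""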
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.ema:
            self.ema.apply_shadow()
            checkpoint["model_state_dict_ema"] = self.model.state_dict()
            self.ema.restore()
        torch.save(checkpoint, filename)

    def train(self, dataloader, num_epochs):
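        """Run the full training loop with progress reporting, saving a checkpoint every 5 epochs."""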
        logger.info("🚄 Start Training!")
        with self.progress.progress:
            self.progress.start_train(num_epochs)
            for epoch in range(num_epochs):
                self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
                epoch_loss = self.train_one_epoch(dataloader)
                self.progress.finish_one_epoch()
                logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
                if (epoch + 1) % 5 == 0:
                    self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
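

# Minimal usage sketch (hypothetical setup; assumes `cfg` is a fully populated Config and
# `dataloader` yields (image_batch, target_batch) pairs as expected by the loss function):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     trainer = ModelTrainer(cfg, save_path="runs/train", device=device)
#     trainer.train(dataloader, num_epochs=100)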