# YOLO/tools/trainer.py
import torch
from loguru import logger
from tqdm import tqdm

from config.config import TrainConfig
from model.yolo import YOLO
from tools.model_helper import EMA, get_optimizer, get_scheduler
from utils.loss import get_loss_function


class Trainer:
    def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
        self.model = model.to(device)
        self.device = device
        self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
        self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
        self.loss_fn = get_loss_function()

        # Exponential moving average of the weights is optional and is
        # switched on via the training config.
        if train_cfg.ema.get("enabled", False):
            self.ema = EMA(model, decay=train_cfg.ema.decay)
        else:
            self.ema = None

    def train_one_batch(self, data, targets):
        data, targets = data.to(self.device), targets.to(self.device)

        # Standard forward / backward / update step.
        self.optimizer.zero_grad()
        outputs = self.model(data)
        loss = self.loss_fn(outputs, targets)
        loss.backward()
        self.optimizer.step()

        # Keep the EMA shadow weights in sync with the live model.
        if self.ema:
            self.ema.update()

        return loss.item()

    def train_one_epoch(self, dataloader):
        self.model.train()
        total_loss = 0

        for data, targets in tqdm(dataloader, desc="Training"):
            loss = self.train_one_batch(data, targets)
            total_loss += loss

        # The learning-rate scheduler is stepped once per epoch.
        if self.scheduler:
            self.scheduler.step()

        return total_loss / len(dataloader)

    def save_checkpoint(self, epoch, filename="checkpoint.pt"):
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.ema:
            # Temporarily swap the EMA shadow weights into the model so they
            # can be stored alongside the raw weights, then restore them.
            self.ema.apply_shadow()
            checkpoint["model_state_dict_ema"] = self.model.state_dict()
            self.ema.restore()
        torch.save(checkpoint, filename)

    def train(self, dataloader, num_epochs):
        logger.info("Start training")
        for epoch in range(num_epochs):
            epoch_loss = self.train_one_epoch(dataloader)
            logger.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")
            # Save a checkpoint every 5 epochs.
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch + 1}.pth")
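

# --- Usage sketch ---------------------------------------------------------
# A minimal illustration of how Trainer is intended to be driven. The names
# below are assumptions for illustration only: the real entry point builds
# the model, TrainConfig, and dataloader from the project's config system,
# and `build_dataloader` is a hypothetical stand-in for that data pipeline.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = YOLO(model_cfg)                  # model_cfg: assumed model config
#     train_cfg = TrainConfig(...)             # optimizer / scheduler / ema settings
#     dataloader = build_dataloader(data_cfg)  # hypothetical data helper
#
#     trainer = Trainer(model, train_cfg, device)
#     trainer.train(dataloader, num_epochs=100)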