import torch
from loguru import logger
from tqdm import tqdm

from config.config import TrainConfig
from model.yolo import YOLO
from tools.model_helper import EMA, get_optimizer, get_scheduler
from utils.loss import get_loss_function


class Trainer:
    """Minimal training loop for a YOLO model.

    Wires together the optimizer, scheduler, loss function and optional EMA
    produced by the project helpers, and exposes per-batch, per-epoch and
    full-run entry points plus checkpoint saving.
    """

    def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
        """Move ``model`` to ``device`` and build the training components.

        Args:
            model: The network to train; ``.to(device)`` is applied to it.
            train_cfg: Training configuration. ``optimizer`` and ``scheduler``
                are forwarded to the project factories; ``ema`` is accessed
                both via ``.get("enabled", ...)`` and ``.decay`` (presumably a
                config node supporting both styles -- TODO confirm).
            device: Target device for the model and every batch.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
        # NOTE(review): callers guard on None, so get_scheduler presumably may
        # return None when no schedule is configured -- confirm.
        self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
        self.loss_fn = get_loss_function()
        if train_cfg.ema.get("enabled", False):
            self.ema = EMA(model, decay=train_cfg.ema.decay)
        else:
            self.ema = None

    def train_one_batch(self, data, targets) -> float:
        """Run one optimization step on a batch and return its scalar loss.

        Args:
            data: Model input; moved to ``self.device``.
            targets: Ground truth passed to the loss; moved to ``self.device``.

        Returns:
            The batch loss as a Python float.
        """
        data, targets = data.to(self.device), targets.to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(data)
        loss = self.loss_fn(outputs, targets)
        loss.backward()
        self.optimizer.step()
        # Update the moving average only after the optimizer changed the weights.
        if self.ema is not None:
            self.ema.update()
        return loss.item()

    def train_one_epoch(self, dataloader) -> float:
        """Train for one pass over ``dataloader`` and return the mean batch loss.

        The scheduler, when present, is stepped after every batch (not once
        per epoch).

        Args:
            dataloader: Iterable of ``(data, targets)`` batches.

        Returns:
            Mean of the per-batch losses, or NaN (with a warning) when the
            dataloader yields no batches, instead of raising ZeroDivisionError.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for data, targets in tqdm(dataloader, desc="Training"):
            total_loss += self.train_one_batch(data, targets)
            num_batches += 1
            if self.scheduler is not None:
                self.scheduler.step()
        if num_batches == 0:
            logger.warning("Dataloader yielded no batches; epoch loss is undefined")
            return float("nan")
        return total_loss / num_batches

    def save_checkpoint(self, epoch, filename="checkpoint.pt") -> None:
        """Save model, optimizer, scheduler and EMA state to ``filename``.

        Args:
            epoch: Stored verbatim under ``"epoch"``; ``train`` passes the
                0-based epoch index.
            filename: Destination path handed to ``torch.save``.
        """
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        # Extra key for resuming the LR schedule; loaders that ignore it still work.
        if self.scheduler is not None and hasattr(self.scheduler, "state_dict"):
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
        if self.ema is not None:
            self.ema.apply_shadow()
            try:
                # state_dict() tensors share storage with the live parameters.
                # Clone them so the EMA snapshot is not overwritten when
                # restore() puts the training weights back in place.
                ema_state = self.model.state_dict()
                for key in list(ema_state):
                    ema_state[key] = ema_state[key].clone()
                checkpoint["model_state_dict_ema"] = ema_state
            finally:
                # Always return to the training weights, even if snapshotting fails.
                self.ema.restore()
        torch.save(checkpoint, filename)

    def train(self, dataloader, num_epochs, save_interval: int = 5) -> None:
        """Train for ``num_epochs`` epochs, checkpointing periodically.

        Args:
            dataloader: Iterable of ``(data, targets)`` batches, reused each epoch.
            num_epochs: Number of epochs to run.
            save_interval: Save ``checkpoint_epoch_{N}.pth`` (N is 1-based)
                after every ``save_interval``-th epoch. Defaults to 5, the
                previously hard-coded value; a value <= 0 disables periodic
                checkpoints.
        """
        logger.info("start train")
        for epoch in range(num_epochs):
            epoch_loss = self.train_one_epoch(dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
            if save_interval > 0 and (epoch + 1) % save_interval == 0:
                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")