|
import torch |
|
from loguru import logger |
|
from torch import Tensor |
|
|
|
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
|
from yolo.config.config import Config, TrainConfig |
|
from yolo.model.yolo import YOLO |
|
from yolo.tools.data_loader import StreamDataLoader |
|
from yolo.tools.drawer import draw_bboxes |
|
from yolo.tools.loss_functions import get_loss_function |
|
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms |
|
from yolo.utils.logging_utils import ProgressTracker |
|
from yolo.utils.model_utils import ( |
|
ExponentialMovingAverage, |
|
create_optimizer, |
|
create_scheduler, |
|
) |
|
|
|
|
|
class ModelTrainer:
    """Drive mixed-precision training of a YOLO model and write checkpoints.

    The trainer owns the optimizer, the optional LR scheduler, the loss
    function, a progress tracker, an optional exponential moving average
    (EMA) of the weights, and the AMP gradient scaler.
    """

    # Number of epochs between two checkpoint files written by ``solve``.
    CHECKPOINT_INTERVAL = 5

    def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
        """Build every training component from ``cfg``.

        Args:
            cfg: Project configuration; ``cfg.task`` is read as the training
                section (optimizer, scheduler, epoch and optional ``ema``).
            model: The network to optimise.
            save_path: Forwarded to the progress tracker.
            device: Device that each batch is moved to before the forward pass.
        """
        train_cfg: TrainConfig = cfg.task
        self.model = model
        self.device = device
        self.optimizer = create_optimizer(model, train_cfg.optimizer)
        self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
        self.loss_fn = get_loss_function(cfg)
        self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
        self.num_epochs = train_cfg.epoch

        # EMA is optional: tolerate a config with no ``ema`` section at all as
        # well as one whose ``enabled`` flag is missing or false.
        ema_cfg = getattr(train_cfg, "ema", None)
        if getattr(ema_cfg, "enabled", False):
            self.ema = ExponentialMovingAverage(model, decay=ema_cfg.decay)
        else:
            self.ema = None
        self.scaler = GradScaler()

    def train_one_batch(self, data: Tensor, targets: Tensor):
        """Run one optimisation step on a single batch.

        Args:
            data: Input batch; moved to ``self.device``.
            targets: Matching targets; moved to ``self.device``.

        Returns:
            A ``(loss_value, loss_item)`` tuple: the scalar loss as a Python
            number and the second value returned by the loss function.
        """
        data, targets = data.to(self.device), targets.to(self.device)
        self.optimizer.zero_grad()

        # Forward pass and loss run under autocast; the backward pass goes
        # through the scaler so the optimizer sees unscaled gradients.
        with autocast():
            outputs = self.model(data)
            loss, loss_item = self.loss_fn(outputs, targets)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        # Without this the EMA shadow never moves and the EMA weights stored
        # by ``save_checkpoint`` stay at their initial values.
        # NOTE(review): assumes ExponentialMovingAverage exposes ``update()``
        # alongside ``apply_shadow()`` / ``restore()`` - confirm.
        if self.ema is not None:
            self.ema.update()

        return loss.item(), loss_item

    def train_one_epoch(self, dataloader):
        """Train over every batch of ``dataloader`` once.

        Args:
            dataloader: Iterable yielding ``(data, targets)`` pairs.

        Returns:
            The mean per-batch loss, or ``0.0`` when the loader yields no
            batches (instead of raising ``ZeroDivisionError``).
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for data, targets in dataloader:
            loss, loss_each = self.train_one_batch(data, targets)
            total_loss += loss
            num_batches += 1
            self.progress.one_batch(loss_each)

        # The scheduler is advanced once per epoch, not once per batch.
        if self.scheduler:
            self.scheduler.step()

        if num_batches == 0:
            return 0.0
        return total_loss / num_batches

    def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
        """Serialise model and optimizer state (plus EMA weights) to a file.

        Args:
            epoch: Epoch index stored under the ``"epoch"`` key.
            filename: Destination passed to ``torch.save``.
        """
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.ema is not None:
            self.ema.apply_shadow()
            try:
                # ``state_dict()`` returns tensors that share storage with the
                # live parameters, so take a copy: otherwise ``restore()``
                # could overwrite the EMA snapshot before it is written.
                checkpoint["model_state_dict_ema"] = {
                    name: value.detach().clone() if torch.is_tensor(value) else value
                    for name, value in self.model.state_dict().items()
                }
            finally:
                # Always hand the model back, even if the snapshot failed.
                self.ema.restore()
        torch.save(checkpoint, filename)

    def solve(self, dataloader):
        """Run the full training loop for ``self.num_epochs`` epochs.

        A checkpoint named ``checkpoint_epoch_<n>.pth`` is written every
        ``CHECKPOINT_INTERVAL`` epochs.

        Args:
            dataloader: Sized iterable of ``(data, targets)`` batches.
        """
        logger.info("π Start Training!")
        num_epochs = self.num_epochs

        with self.progress.progress:
            self.progress.start_train(num_epochs)
            for epoch in range(num_epochs):
                self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
                epoch_loss = self.train_one_epoch(dataloader)
                self.progress.finish_one_epoch()

                logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
                if (epoch + 1) % self.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|
|
|
|
|
class ModelTester:
    """Run a YOLO model over a stream of images and save the drawn detections."""

    def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
        """Store the model and build the post-processing helpers.

        Args:
            cfg: Project configuration; ``cfg.task.nms`` is kept as the NMS
                settings passed to ``bbox_nms``.
            model: The network used for inference.
            save_path: Directory handed to ``draw_bboxes`` for output images.
            device: Device that each image batch is moved to.
        """
        self.model = model
        self.device = device
        self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)

        self.anchor2box = AnchorBoxConverter(cfg, device)
        self.nms = cfg.task.nms
        self.save_path = save_path

    def solve(self, dataloader: StreamDataLoader):
        """Run inference on every batch from ``dataloader`` and save the results.

        For each batch the first image is drawn with the first entry of the
        NMS output and saved as ``frame<idx>.png``. The loader is stopped
        exactly once however the loop ends.

        Args:
            dataloader: Streaming loader; must expose ``stop_event`` and
                ``stop()``.

        Raises:
            Exception: Any error raised during inference is logged and
                re-raised. ``KeyboardInterrupt`` is logged and swallowed.
        """
        logger.info("π Start Inference!")

        # Inference must not run layers such as dropout or batch-norm in
        # training mode; modules start out in training mode by default.
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()

        try:
            for idx, images in enumerate(dataloader):
                images = images.to(self.device)
                with torch.no_grad():
                    raw_output = self.model(images)
                # NOTE(review): presumably the first three entries of
                # ``raw_output[0]`` are not detection heads - confirm.
                predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
                nms_out = bbox_nms(predict, self.nms)
                draw_bboxes(
                    images[0],
                    nms_out[0],
                    scaled_bbox=False,
                    save_path=self.save_path,
                    save_name=f"frame{idx:03d}.png",
                )
        except KeyboardInterrupt:
            logger.error("Interrupted by user")
            dataloader.stop_event.set()
        except Exception as e:
            logger.error(e)
            dataloader.stop_event.set()
            # Bare ``raise`` keeps the original traceback intact.
            raise
        finally:
            # Single cleanup point: previously an interrupt stopped the loader
            # twice (once in the handler, once after the try block).
            dataloader.stop()
|
|