π [Rename] class Trainer -> Solver
Browse files
yolo/tools/{trainer.py β solver.py}
RENAMED
@@ -6,8 +6,10 @@ from torch import Tensor
|
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
-
from yolo.model.yolo import
|
|
|
10 |
from yolo.tools.loss_functions import get_loss_function
|
|
|
11 |
from yolo.utils.logging_utils import ProgressTracker
|
12 |
from yolo.utils.model_utils import (
|
13 |
ExponentialMovingAverage,
|
@@ -17,16 +19,15 @@ from yolo.utils.model_utils import (
|
|
17 |
|
18 |
|
19 |
class ModelTrainer:
|
20 |
-
def __init__(self, cfg: Config, save_path: str, device):
|
21 |
train_cfg: TrainConfig = cfg.task
|
22 |
-
model =
|
23 |
-
|
24 |
-
self.model = model.to(device)
|
25 |
self.device = device
|
26 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
27 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
28 |
self.loss_fn = get_loss_function(cfg)
|
29 |
-
self.progress = ProgressTracker(cfg, save_path, use_wandb
|
|
|
30 |
|
31 |
if getattr(train_cfg.ema, "enabled", False):
|
32 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
@@ -75,8 +76,9 @@ class ModelTrainer:
|
|
75 |
self.ema.restore()
|
76 |
torch.save(checkpoint, filename)
|
77 |
|
78 |
-
def
|
79 |
logger.info("π Start Training!")
|
|
|
80 |
|
81 |
with self.progress.progress:
|
82 |
self.progress.start_train(num_epochs)
|
@@ -89,3 +91,27 @@ class ModelTrainer:
|
|
89 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
90 |
if (epoch + 1) % 5 == 0:
|
91 |
self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from torch.cuda.amp import GradScaler, autocast
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
+
from yolo.model.yolo import YOLO
|
10 |
+
from yolo.tools.drawer import draw_bboxes
|
11 |
from yolo.tools.loss_functions import get_loss_function
|
12 |
+
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
13 |
from yolo.utils.logging_utils import ProgressTracker
|
14 |
from yolo.utils.model_utils import (
|
15 |
ExponentialMovingAverage,
|
|
|
19 |
|
20 |
|
21 |
class ModelTrainer:
|
22 |
+
def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
|
23 |
train_cfg: TrainConfig = cfg.task
|
24 |
+
self.model = model
|
|
|
|
|
25 |
self.device = device
|
26 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
27 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
28 |
self.loss_fn = get_loss_function(cfg)
|
29 |
+
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
30 |
+
self.num_epochs = cfg.task.epoch
|
31 |
|
32 |
if getattr(train_cfg.ema, "enabled", False):
|
33 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
|
|
76 |
self.ema.restore()
|
77 |
torch.save(checkpoint, filename)
|
78 |
|
79 |
+
def solve(self, dataloader):
|
80 |
logger.info("π Start Training!")
|
81 |
+
num_epochs = self.num_epochs
|
82 |
|
83 |
with self.progress.progress:
|
84 |
self.progress.start_train(num_epochs)
|
|
|
91 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
92 |
if (epoch + 1) % 5 == 0:
|
93 |
self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|
94 |
+
|
95 |
+
|
96 |
+
class ModelTester:
|
97 |
+
def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
|
98 |
+
self.model = model
|
99 |
+
self.device = device
|
100 |
+
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
101 |
+
|
102 |
+
self.anchor2box = AnchorBoxConverter(cfg, device)
|
103 |
+
self.nms = cfg.task.nms
|
104 |
+
self.save_path = save_path
|
105 |
+
|
106 |
+
def solve(self, dataloader):
|
107 |
+
logger.info("π Start Inference!")
|
108 |
+
|
109 |
+
for images, _ in dataloader:
|
110 |
+
images = images.to(self.device)
|
111 |
+
with torch.no_grad():
|
112 |
+
raw_output = self.model(images)
|
113 |
+
predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
|
114 |
+
|
115 |
+
nms_out = bbox_nms(predict, self.nms)
|
116 |
+
for image, bbox in zip(images, nms_out):
|
117 |
+
draw_bboxes(image, bbox, scaled_bbox=False, save_path=self.save_path)
|