henry000 commited on
Commit
9eb2d4e
Β·
1 Parent(s): 8f0b970

🚚 [Rename] class Trainer -> Solver

Browse files
yolo/tools/{trainer.py β†’ solver.py} RENAMED
@@ -6,8 +6,10 @@ from torch import Tensor
6
  from torch.cuda.amp import GradScaler, autocast
7
 
8
  from yolo.config.config import Config, TrainConfig
9
- from yolo.model.yolo import get_model
 
10
  from yolo.tools.loss_functions import get_loss_function
 
11
  from yolo.utils.logging_utils import ProgressTracker
12
  from yolo.utils.model_utils import (
13
  ExponentialMovingAverage,
@@ -17,16 +19,15 @@ from yolo.utils.model_utils import (
17
 
18
 
19
  class ModelTrainer:
20
- def __init__(self, cfg: Config, save_path: str, device):
21
  train_cfg: TrainConfig = cfg.task
22
- model = get_model(cfg)
23
-
24
- self.model = model.to(device)
25
  self.device = device
26
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
27
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
28
  self.loss_fn = get_loss_function(cfg)
29
- self.progress = ProgressTracker(cfg, save_path, use_wandb=True)
 
30
 
31
  if getattr(train_cfg.ema, "enabled", False):
32
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
@@ -75,8 +76,9 @@ class ModelTrainer:
75
  self.ema.restore()
76
  torch.save(checkpoint, filename)
77
 
78
- def train(self, dataloader, num_epochs):
79
  logger.info("πŸš„ Start Training!")
 
80
 
81
  with self.progress.progress:
82
  self.progress.start_train(num_epochs)
@@ -89,3 +91,27 @@ class ModelTrainer:
89
  logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
90
  if (epoch + 1) % 5 == 0:
91
  self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from torch.cuda.amp import GradScaler, autocast
7
 
8
  from yolo.config.config import Config, TrainConfig
9
+ from yolo.model.yolo import YOLO
10
+ from yolo.tools.drawer import draw_bboxes
11
  from yolo.tools.loss_functions import get_loss_function
12
+ from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
13
  from yolo.utils.logging_utils import ProgressTracker
14
  from yolo.utils.model_utils import (
15
  ExponentialMovingAverage,
 
19
 
20
 
21
  class ModelTrainer:
22
+ def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
23
  train_cfg: TrainConfig = cfg.task
24
+ self.model = model
 
 
25
  self.device = device
26
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
27
  self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
28
  self.loss_fn = get_loss_function(cfg)
29
+ self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
30
+ self.num_epochs = cfg.task.epoch
31
 
32
  if getattr(train_cfg.ema, "enabled", False):
33
  self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
 
76
  self.ema.restore()
77
  torch.save(checkpoint, filename)
78
 
79
+ def solve(self, dataloader):
80
  logger.info("πŸš„ Start Training!")
81
+ num_epochs = self.num_epochs
82
 
83
  with self.progress.progress:
84
  self.progress.start_train(num_epochs)
 
91
  logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
92
  if (epoch + 1) % 5 == 0:
93
  self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
94
+
95
+
96
+ class ModelTester:
97
+ def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
98
+ self.model = model
99
+ self.device = device
100
+ self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
101
+
102
+ self.anchor2box = AnchorBoxConverter(cfg, device)
103
+ self.nms = cfg.task.nms
104
+ self.save_path = save_path
105
+
106
+ def solve(self, dataloader):
107
+ logger.info("πŸ‘€ Start Inference!")
108
+
109
+ for images, _ in dataloader:
110
+ images = images.to(self.device)
111
+ with torch.no_grad():
112
+ raw_output = self.model(images)
113
+ predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
114
+
115
+ nms_out = bbox_nms(predict, self.nms)
116
+ for image, bbox in zip(images, nms_out):
117
+ draw_bboxes(image, bbox, scaled_bbox=False, save_path=self.save_path)