henry000 commited on
Commit
f0fdf9a
·
1 Parent(s): a51f159

🔧 [Update] Trainer input config

Browse files
examples/example_train.py CHANGED
@@ -28,7 +28,7 @@ def main(cfg: Config):
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- trainer = Trainer(model, cfg.hyper.train, device)
32
  trainer.train(dataloader, 10)
33
 
34
 
 
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
+ trainer = Trainer(model, cfg, device)
32
  trainer.train(dataloader, 10)
33
 
34
 
yolo/tools/trainer.py CHANGED
@@ -9,7 +9,9 @@ from yolo.utils.loss import get_loss_function
9
 
10
 
11
  class Trainer:
12
- def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
 
 
13
  self.model = model.to(device)
14
  self.device = device
15
  self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
 
9
 
10
 
11
  class Trainer:
12
+ def __init__(self, model: YOLO, cfg: Config, device):
13
+ train_cfg: TrainConfig = cfg.hyper.train
14
+
15
  self.model = model.to(device)
16
  self.device = device
17
  self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)