🔧 [Update] Trainer input config
Browse files- examples/example_train.py +1 -1
- yolo/tools/trainer.py +3 -1
examples/example_train.py
CHANGED
@@ -28,7 +28,7 @@ def main(cfg: Config):
|
|
28 |
# TODO: get_device or rank, for DDP mode
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
|
31 |
-
trainer = Trainer(model, cfg
|
32 |
trainer.train(dataloader, 10)
|
33 |
|
34 |
|
|
|
28 |
# TODO: get_device or rank, for DDP mode
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
|
31 |
+
trainer = Trainer(model, cfg, device)
|
32 |
trainer.train(dataloader, 10)
|
33 |
|
34 |
|
yolo/tools/trainer.py
CHANGED
@@ -9,7 +9,9 @@ from yolo.utils.loss import get_loss_function
|
|
9 |
|
10 |
|
11 |
class Trainer:
|
12 |
-
def __init__(self, model: YOLO,
|
|
|
|
|
13 |
self.model = model.to(device)
|
14 |
self.device = device
|
15 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
|
|
9 |
|
10 |
|
11 |
class Trainer:
|
12 |
+
def __init__(self, model: YOLO, cfg: Config, device):
|
13 |
+
train_cfg: TrainConfig = cfg.hyper.train
|
14 |
+
|
15 |
self.model = model.to(device)
|
16 |
self.device = device
|
17 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|