✨ [Init] Trainer for training the whole model!
- config/config.py +52 -0
- config/hyper/default.yaml +14 -0
- tools/model_helper.py +51 -0
- tools/trainer.py +63 -0
- train.py +9 -2
- utils/loss.py +2 -0
config/config.py
CHANGED
@@ -14,7 +14,59 @@ class Download:
     path: str


+@dataclass
+class DataLoaderConfig:
+    batch_size: int
+    shuffle: bool
+    num_workers: int
+    pin_memory: bool
+
+
+@dataclass
+class OptimizerArgs:
+    lr: float
+    weight_decay: float
+
+
+@dataclass
+class OptimizerConfig:
+    type: str
+    args: OptimizerArgs
+
+
+@dataclass
+class SchedulerArgs:
+    step_size: int
+    gamma: float
+
+
+@dataclass
+class SchedulerConfig:
+    type: str
+    args: SchedulerArgs
+
+
+@dataclass
+class EMAConfig:
+    enabled: bool
+    decay: float
+
+
+@dataclass
+class TrainConfig:
+    optimizer: OptimizerConfig
+    scheduler: SchedulerConfig
+    ema: EMAConfig
+
+
+@dataclass
+class HyperConfig:
+    data: DataLoaderConfig
+    train: TrainConfig
+
+
 @dataclass
 class Config:
     model: Model
     download: Download
+    hyper: HyperConfig
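The new dataclasses mirror config/hyper/default.yaml: HyperConfig groups the dataloader settings with a TrainConfig, which in turn nests optimizer, scheduler, and EMA settings. A minimal sketch of how this schema could be registered with Hydra's ConfigStore for validation; hedged, since the repo may wire the schema up elsewhere and the store name here is only an example:

from hydra.core.config_store import ConfigStore

from config.config import Config

cs = ConfigStore.instance()
# "config_schema" is a hypothetical name; once the defaults list points at this
# node, Hydra type-checks the composed YAML against the dataclasses above.
cs.store(name="config_schema", node=Config)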
config/hyper/default.yaml
CHANGED
@@ -3,3 +3,17 @@ data:
   shuffle: True
   num_workers: 4
   pin_memory: True
+train:
+  optimizer:
+    type: Adam
+    args:
+      lr: 0.001
+      weight_decay: 0.0001
+  scheduler:
+    type: StepLR
+    args:
+      step_size: 10
+      gamma: 0.1
+  ema:
+    enabled: true
+    decay: 0.995
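These values land on TrainConfig at runtime, so cfg.hyper.train.optimizer.args.lr reads back 0.001 once Hydra composes the config. A small sketch of reading the file directly with OmegaConf, purely illustrative since the project loads it through Hydra composition:

from omegaconf import OmegaConf

hyper = OmegaConf.load("config/hyper/default.yaml")
print(hyper.train.optimizer.type)        # "Adam"
print(hyper.train.scheduler.args.gamma)  # 0.1
print(hyper.train.ema.decay)             # 0.995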
tools/model_helper.py
ADDED
@@ -0,0 +1,51 @@
+from typing import Any, Dict, Type
+
+import torch
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from config.config import OptimizerConfig, SchedulerConfig
+
+
+class EMA:
+    def __init__(self, model: torch.nn.Module, decay: float):
+        self.model = model
+        self.decay = decay
+        self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
+
+    def update(self):
+        """Update the shadow parameters using the current model parameters."""
+        for name, param in self.model.named_parameters():
+            assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
+            new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
+            self.shadow[name] = new_average.clone()
+
+    def apply_shadow(self):
+        """Apply the shadow parameters to the model."""
+        for name, param in self.model.named_parameters():
+            param.data.copy_(self.shadow[name])
+
+    def restore(self):
+        """Restore the original parameters from the shadow."""
+        for name, param in self.model.named_parameters():
+            self.shadow[name].copy_(param.data)
+
+
+def get_optimizer(model_parameters, optim_cfg: OptimizerConfig) -> Optimizer:
+    """Create an optimizer for the given model parameters based on the configuration.
+
+    Returns:
+        An instance of the optimizer configured according to the provided settings.
+    """
+    optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
+    return optimizer_class(model_parameters, **optim_cfg.args)
+
+
+def get_scheduler(optimizer: Optimizer, schedul_cfg: SchedulerConfig) -> _LRScheduler:
+    """Create a learning rate scheduler for the given optimizer based on the configuration.
+
+    Returns:
+        An instance of the scheduler configured according to the provided settings.
+    """
+    scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedul_cfg.type)
+    return scheduler_class(optimizer, **schedul_cfg.args)
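A usage sketch for the helpers above, using a stand-in model and the values from config/hyper/default.yaml; the OmegaConf construction is only illustrative, since in the project these configs arrive through Hydra:

import torch.nn as nn
from omegaconf import OmegaConf

from tools.model_helper import EMA, get_optimizer, get_scheduler

model = nn.Linear(8, 2)  # stand-in for the real YOLO model
train_cfg = OmegaConf.create({
    "optimizer": {"type": "Adam", "args": {"lr": 0.001, "weight_decay": 0.0001}},
    "scheduler": {"type": "StepLR", "args": {"step_size": 10, "gamma": 0.1}},
})

optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)  # torch.optim.Adam
scheduler = get_scheduler(optimizer, train_cfg.scheduler)           # StepLR
ema = EMA(model, decay=0.995)
# After each optimizer.step(), ema.update() moves the shadow weights toward
# the live weights at the configured decay rate.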
tools/trainer.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+from loguru import logger
+from tqdm import tqdm
+
+from config.config import TrainConfig
+from model.yolo import YOLO
+from tools.model_helper import EMA, get_optimizer, get_scheduler
+from utils.loss import get_loss_function
+
+
+class Trainer:
+    def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
+        self.model = model.to(device)
+        self.device = device
+        self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
+        self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
+        self.loss_fn = get_loss_function()
+
+        if train_cfg.ema.get("enabled", False):
+            self.ema = EMA(model, decay=train_cfg.ema.decay)
+        else:
+            self.ema = None
+
+    def train_one_batch(self, data, targets):
+        data, targets = data.to(self.device), targets.to(self.device)
+        self.optimizer.zero_grad()
+        outputs = self.model(data)
+        loss = self.loss_fn(outputs, targets)
+        loss.backward()
+        self.optimizer.step()
+        if self.ema:
+            self.ema.update()
+        return loss.item()
+
+    def train_one_epoch(self, dataloader):
+        self.model.train()
+        total_loss = 0
+        for data, targets in tqdm(dataloader, desc="Training"):
+            loss = self.train_one_batch(data, targets)
+            total_loss += loss
+        if self.scheduler:
+            self.scheduler.step()
+        return total_loss / len(dataloader)
+
+    def save_checkpoint(self, epoch, filename="checkpoint.pt"):
+        checkpoint = {
+            "epoch": epoch,
+            "model_state_dict": self.model.state_dict(),
+            "optimizer_state_dict": self.optimizer.state_dict(),
+        }
+        if self.ema:
+            self.ema.apply_shadow()
+            checkpoint["model_state_dict_ema"] = self.model.state_dict()
+            self.ema.restore()
+        torch.save(checkpoint, filename)
+
+    def train(self, dataloader, num_epochs):
+        logger.info("start train")
+        for epoch in range(num_epochs):
+            epoch_loss = self.train_one_epoch(dataloader)
+            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
+            if (epoch + 1) % 5 == 0:
+                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
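Trainer does not yet resume from saved state; a hedged sketch of a loader whose keys match what save_checkpoint writes above (load_checkpoint and its arguments are hypothetical, and the path is only an example):

import torch

def load_checkpoint(model, optimizer, path="checkpoint_epoch_5.pth", use_ema=False):
    # Keys mirror save_checkpoint: epoch, model_state_dict, optimizer_state_dict,
    # and model_state_dict_ema when EMA was enabled during training.
    checkpoint = torch.load(path, map_location="cpu")
    state_key = "model_state_dict_ema" if use_ema and "model_state_dict_ema" in checkpoint else "model_state_dict"
    model.load_state_dict(checkpoint[state_key])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1  # next epoch to resume from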
train.py
CHANGED
@@ -1,20 +1,27 @@
 import hydra
+import torch
 from loguru import logger
 
 from config.config import Config
 from model.yolo import get_model
 from tools.log_helper import custom_logger
-from
+from tools.trainer import Trainer
+from utils.dataloader import get_dataloader
 from utils.get_dataset import prepare_dataset
 
 
 @hydra.main(config_path="config", config_name="config", version_base=None)
 def main(cfg: Config):
-    dataset = YoloDataset(cfg)
     if cfg.download.auto:
         prepare_dataset(cfg.download)
 
+    dataloader = get_dataloader(cfg)
     model = get_model(cfg.model)
+    # TODO: get_device or rank, for DDP mode
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    trainer = Trainer(model, cfg.hyper.train, device)
+    trainer.train(dataloader, 10)
 
 
 if __name__ == "__main__":
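The TODO about device/rank resolution is left open in this commit; one possible shape for it under torchrun-launched DDP (get_device is a hypothetical helper and assumes the launcher sets LOCAL_RANK):

import os

import torch

def get_device() -> torch.device:
    # Hypothetical sketch of the TODO above: pick the local rank's GPU when
    # launched with torchrun, otherwise fall back to the single-device choice.
    if torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        return torch.device("cuda", local_rank)
    return torch.device("cpu")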
utils/loss.py
ADDED
@@ -0,0 +1,2 @@
+def get_loss_function(*args, **kwargs):
+    raise NotImplementedError
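get_loss_function is only a stub in this commit; a throwaway placeholder like the one below (not the real YOLO detection loss) would let the Trainer loop be smoke-tested end to end until the actual loss lands:

import torch.nn as nn

def get_loss_function(*args, **kwargs):
    # Placeholder only: stands in for the YOLO loss, which is not implemented yet.
    return nn.MSELoss()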