henry000 committed
Commit 649c592 · 1 Parent(s): 23db031

✨ [Init] Trainer for training whole model!

config/config.py CHANGED
@@ -14,7 +14,59 @@ class Download:
  path: str


+ @dataclass
+ class DataLoaderConfig:
+     batch_size: int
+     shuffle: bool
+     num_workers: int
+     pin_memory: bool
+
+
+ @dataclass
+ class OptimizerArgs:
+     lr: float
+     weight_decay: float
+
+
+ @dataclass
+ class OptimizerConfig:
+     type: str
+     args: OptimizerArgs
+
+
+ @dataclass
+ class SchedulerArgs:
+     step_size: int
+     gamma: float
+
+
+ @dataclass
+ class SchedulerConfig:
+     type: str
+     args: SchedulerArgs
+
+
+ @dataclass
+ class EMAConfig:
+     enabled: bool
+     decay: float
+
+
+ @dataclass
+ class TrainConfig:
+     optimizer: OptimizerConfig
+     scheduler: SchedulerConfig
+     ema: EMAConfig
+
+
+ @dataclass
+ class HyperConfig:
+     data: DataLoaderConfig
+     train: TrainConfig
+
+
  @dataclass
  class Config:
      model: Model
      download: Download
+     hyper: HyperConfig
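
For orientation, the nesting of these dataclasses mirrors the hyperparameter YAML in the next file. Built by hand with the default values from config/hyper/default.yaml, a TrainConfig would look roughly like this (a sketch only; at runtime Hydra builds the equivalent DictConfig from the YAML instead):

    from config.config import (
        EMAConfig,
        OptimizerArgs,
        OptimizerConfig,
        SchedulerArgs,
        SchedulerConfig,
        TrainConfig,
    )

    train_cfg = TrainConfig(
        optimizer=OptimizerConfig(type="Adam", args=OptimizerArgs(lr=0.001, weight_decay=0.0001)),
        scheduler=SchedulerConfig(type="StepLR", args=SchedulerArgs(step_size=10, gamma=0.1)),
        ema=EMAConfig(enabled=True, decay=0.995),
    )
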
config/hyper/default.yaml CHANGED
@@ -3,3 +3,17 @@ data:
  shuffle: True
  num_workers: 4
  pin_memory: True
+ train:
+   optimizer:
+     type: Adam
+     args:
+       lr: 0.001
+       weight_decay: 0.0001
+   scheduler:
+     type: StepLR
+     args:
+       step_size: 10
+       gamma: 0.1
+   ema:
+     enabled: true
+     decay: 0.995
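
Loaded through Hydra, this file fills the new HyperConfig schema. The same merge can be reproduced standalone with OmegaConf, which is handy for quick checks (a minimal sketch, assuming it is run from the repository root):

    from omegaconf import OmegaConf

    from config.config import HyperConfig

    # Validate the YAML against the structured HyperConfig schema and merge the values in.
    schema = OmegaConf.structured(HyperConfig)
    hyper = OmegaConf.merge(schema, OmegaConf.load("config/hyper/default.yaml"))

    print(hyper.train.optimizer.type)  # Adam
    print(hyper.train.ema.decay)       # 0.995
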
tools/model_helper.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Type
+
+ import torch
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+ from config.config import OptimizerConfig, SchedulerConfig
+
+
+ class EMA:
+     def __init__(self, model: torch.nn.Module, decay: float):
+         self.model = model
+         self.decay = decay
+         # Shadow weights start as a copy of the current model parameters.
+         self.shadow = {name: param.clone().detach() for name, param in model.named_parameters()}
+         # Backup of the live training weights while the shadow is applied.
+         self.backup = {}
+
+     def update(self):
+         """Update the shadow parameters using the current model parameters."""
+         for name, param in self.model.named_parameters():
+             assert name in self.shadow, "All model parameters should have a corresponding shadow parameter."
+             new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
+             self.shadow[name] = new_average.clone()
+
+     def apply_shadow(self):
+         """Back up the current parameters and apply the shadow (EMA) parameters to the model."""
+         self.backup = {name: param.data.clone() for name, param in self.model.named_parameters()}
+         for name, param in self.model.named_parameters():
+             param.data.copy_(self.shadow[name])
+
+     def restore(self):
+         """Restore the backed-up original parameters to the model."""
+         for name, param in self.model.named_parameters():
+             param.data.copy_(self.backup[name])
+
+
+ def get_optimizer(model_parameters, optim_cfg: OptimizerConfig) -> Optimizer:
+     """Create an optimizer for the given model parameters based on the configuration.
+
+     Returns:
+         An instance of the optimizer configured according to the provided settings.
+     """
+     optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)
+     return optimizer_class(model_parameters, **optim_cfg.args)
+
+
+ def get_scheduler(optimizer: Optimizer, schedul_cfg: SchedulerConfig) -> _LRScheduler:
+     """Create a learning rate scheduler for the given optimizer based on the configuration.
+
+     Returns:
+         An instance of the scheduler configured according to the provided settings.
+     """
+     scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedul_cfg.type)
+     return scheduler_class(optimizer, **schedul_cfg.args)
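
A minimal sketch of how these helpers fit together, using a toy nn.Linear in place of the real YOLO model and hand-built configs in place of Hydra's. Note that get_optimizer and get_scheduler unpack cfg.args with **, so a mapping (a plain dict here, Hydra's DictConfig at runtime) is what they expect:

    import torch

    from config.config import OptimizerConfig, SchedulerConfig
    from tools.model_helper import EMA, get_optimizer, get_scheduler

    model = torch.nn.Linear(8, 2)  # stand-in for the YOLO model
    optimizer = get_optimizer(model.parameters(), OptimizerConfig(type="SGD", args={"lr": 0.01, "weight_decay": 0.0}))
    scheduler = get_scheduler(optimizer, SchedulerConfig(type="StepLR", args={"step_size": 10, "gamma": 0.1}))
    ema = EMA(model, decay=0.995)

    # After each optimizer step, fold the new weights into the moving average.
    optimizer.step()
    ema.update()
    scheduler.step()
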
tools/trainer.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from loguru import logger
+ from tqdm import tqdm
+
+ from config.config import TrainConfig
+ from model.yolo import YOLO
+ from tools.model_helper import EMA, get_optimizer, get_scheduler
+ from utils.loss import get_loss_function
+
+
+ class Trainer:
+     def __init__(self, model: YOLO, train_cfg: TrainConfig, device):
+         self.model = model.to(device)
+         self.device = device
+         self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
+         self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
+         self.loss_fn = get_loss_function()
+
+         if train_cfg.ema.enabled:
+             self.ema = EMA(model, decay=train_cfg.ema.decay)
+         else:
+             self.ema = None
+
+     def train_one_batch(self, data, targets):
+         data, targets = data.to(self.device), targets.to(self.device)
+         self.optimizer.zero_grad()
+         outputs = self.model(data)
+         loss = self.loss_fn(outputs, targets)
+         loss.backward()
+         self.optimizer.step()
+         if self.ema:
+             self.ema.update()
+         return loss.item()
+
+     def train_one_epoch(self, dataloader):
+         self.model.train()
+         total_loss = 0
+         for data, targets in tqdm(dataloader, desc="Training"):
+             loss = self.train_one_batch(data, targets)
+             total_loss += loss
+         if self.scheduler:
+             self.scheduler.step()
+         return total_loss / len(dataloader)
+
+     def save_checkpoint(self, epoch, filename="checkpoint.pt"):
+         checkpoint = {
+             "epoch": epoch,
+             "model_state_dict": self.model.state_dict(),
+             "optimizer_state_dict": self.optimizer.state_dict(),
+         }
+         if self.ema:
+             # Temporarily swap in the EMA weights so they are saved alongside the raw weights.
+             self.ema.apply_shadow()
+             checkpoint["model_state_dict_ema"] = self.model.state_dict()
+             self.ema.restore()
+         torch.save(checkpoint, filename)
+
+     def train(self, dataloader, num_epochs):
+         logger.info("Start training!")
+         for epoch in range(num_epochs):
+             epoch_loss = self.train_one_epoch(dataloader)
+             logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
+             if (epoch + 1) % 5 == 0:
+                 self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
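
Reloading a checkpoint is not part of this commit, but the keys written above would be consumed roughly like this (a hypothetical load_checkpoint helper, shown for illustration only):

    import torch

    def load_checkpoint(trainer, filename, use_ema=False):
        # Keys mirror what Trainer.save_checkpoint writes.
        checkpoint = torch.load(filename, map_location=trainer.device)
        state_key = "model_state_dict_ema" if use_ema and "model_state_dict_ema" in checkpoint else "model_state_dict"
        trainer.model.load_state_dict(checkpoint[state_key])
        trainer.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"] + 1  # epoch to resume from
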
train.py CHANGED
@@ -1,20 +1,27 @@
  import hydra
+ import torch
  from loguru import logger

  from config.config import Config
  from model.yolo import get_model
  from tools.log_helper import custom_logger
- from utils.dataloader import YoloDataset
+ from tools.trainer import Trainer
+ from utils.dataloader import get_dataloader
  from utils.get_dataset import prepare_dataset


  @hydra.main(config_path="config", config_name="config", version_base=None)
  def main(cfg: Config):
-     dataset = YoloDataset(cfg)
      if cfg.download.auto:
          prepare_dataset(cfg.download)

+     dataloader = get_dataloader(cfg)
      model = get_model(cfg.model)
+     # TODO: get_device or rank, for DDP mode
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     trainer = Trainer(model, cfg.hyper.train, device)
+     trainer.train(dataloader, 10)


  if __name__ == "__main__":
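
The TODO above anticipates a device/rank helper for DDP. One hypothetical shape for it (get_device is not part of this commit; LOCAL_RANK is the environment variable set by torchrun):

    import os

    import torch

    def get_device() -> torch.device:
        # Under torchrun/DDP each process picks its GPU from LOCAL_RANK;
        # otherwise fall back to the single-device choice used in train.py.
        local_rank = int(os.environ.get("LOCAL_RANK", -1))
        if torch.cuda.is_available():
            return torch.device("cuda", max(local_rank, 0))
        return torch.device("cpu")
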
utils/loss.py ADDED
@@ -0,0 +1,2 @@
+ def get_loss_function(*args, **kwargs):
+     raise NotImplementedError
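
Since get_loss_function still raises, Trainer.__init__ will fail until a real loss lands. A throwaway placeholder that lets the loop run end-to-end could look like this (purely illustrative; a real YOLO loss combines box regression, objectness, and classification terms):

    import torch.nn as nn

    def get_loss_function(*args, **kwargs):
        # Illustrative stand-in only, not the detection loss this project will ship.
        return nn.MSELoss()
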