[New] DDP mode for training model
Changed files:
- yolo/lazy.py +6 -5
- yolo/tools/data_loader.py +7 -5
- yolo/tools/solver.py +7 -2
- yolo/utils/model_utils.py +21 -27
yolo/lazy.py
CHANGED
@@ -2,7 +2,6 @@ import sys
|
|
2 |
from pathlib import Path
|
3 |
|
4 |
import hydra
|
5 |
-
import torch
|
6 |
|
7 |
project_root = Path(__file__).resolve().parent.parent
|
8 |
sys.path.append(str(project_root))
|
@@ -14,22 +13,24 @@ from yolo.tools.solver import ModelTester, ModelTrainer
|
|
14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
15 |
from yolo.utils.deploy_utils import FastModelLoader
|
16 |
from yolo.utils.logging_utils import ProgressLogger
|
17 |
-
from yolo.utils.model_utils import
|
18 |
|
19 |
|
20 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
21 |
def main(cfg: Config):
|
22 |
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
23 |
-
|
|
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
else:
|
27 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
|
28 |
-
|
|
|
29 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
30 |
|
31 |
if cfg.task.task == "train":
|
32 |
-
trainer = ModelTrainer(cfg, model, vec2box, progress, device)
|
33 |
trainer.solve(dataloader)
|
34 |
|
35 |
if cfg.task.task == "inference":
|
|
|
2 |
from pathlib import Path
|
3 |
|
4 |
import hydra
|
|
|
5 |
|
6 |
project_root = Path(__file__).resolve().parent.parent
|
7 |
sys.path.append(str(project_root))
|
|
|
13 |
from yolo.utils.bounding_box_utils import Vec2Box
|
14 |
from yolo.utils.deploy_utils import FastModelLoader
|
15 |
from yolo.utils.logging_utils import ProgressLogger
|
16 |
+
from yolo.utils.model_utils import get_device
|
17 |
|
18 |
|
19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
20 |
def main(cfg: Config):
|
21 |
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
22 |
+
device, use_ddp = get_device(cfg.device)
|
23 |
+
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
else:
|
27 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
|
28 |
+
model = model.to(device)
|
29 |
+
|
30 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
31 |
|
32 |
if cfg.task.task == "train":
|
33 |
+
trainer = ModelTrainer(cfg, model, vec2box, progress, device, use_ddp)
|
34 |
trainer.solve(dataloader)
|
35 |
|
36 |
if cfg.task.task == "inference":
|
yolo/tools/data_loader.py
CHANGED
@@ -12,7 +12,7 @@ from PIL import Image
|
|
12 |
from rich.progress import track
|
13 |
from torch import Tensor
|
14 |
from torch.utils.data import DataLoader, Dataset
|
15 |
-
from
|
16 |
|
17 |
from yolo.config.config import DataConfig, DatasetConfig
|
18 |
from yolo.tools.data_augmentation import (
|
@@ -157,14 +157,16 @@ class YoloDataset(Dataset):
|
|
157 |
|
158 |
|
159 |
class YoloDataLoader(DataLoader):
|
160 |
-
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
|
161 |
"""Initializes the YoloDataLoader with hydra-config files."""
|
162 |
dataset = YoloDataset(data_cfg, dataset_cfg, task)
|
|
|
163 |
self.image_size = data_cfg.image_size[0]
|
164 |
super().__init__(
|
165 |
dataset,
|
166 |
batch_size=data_cfg.batch_size,
|
167 |
-
|
|
|
168 |
num_workers=data_cfg.cpu_num,
|
169 |
pin_memory=data_cfg.pin_memory,
|
170 |
collate_fn=self.collate_fn,
|
@@ -198,14 +200,14 @@ class YoloDataLoader(DataLoader):
|
|
198 |
return batch_images, batch_targets
|
199 |
|
200 |
|
201 |
-
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
|
202 |
if task == "inference":
|
203 |
return StreamDataLoader(data_cfg)
|
204 |
|
205 |
if dataset_cfg.auto_download:
|
206 |
prepare_dataset(dataset_cfg, task)
|
207 |
|
208 |
-
return YoloDataLoader(data_cfg, dataset_cfg, task)
|
209 |
|
210 |
|
211 |
class StreamDataLoader:
|
|
|
12 |
from rich.progress import track
|
13 |
from torch import Tensor
|
14 |
from torch.utils.data import DataLoader, Dataset
|
15 |
+
from torch.utils.data.distributed import DistributedSampler
|
16 |
|
17 |
from yolo.config.config import DataConfig, DatasetConfig
|
18 |
from yolo.tools.data_augmentation import (
|
|
|
157 |
|
158 |
|
159 |
class YoloDataLoader(DataLoader):
|
160 |
+
def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
|
161 |
"""Initializes the YoloDataLoader with hydra-config files."""
|
162 |
dataset = YoloDataset(data_cfg, dataset_cfg, task)
|
163 |
+
sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
|
164 |
self.image_size = data_cfg.image_size[0]
|
165 |
super().__init__(
|
166 |
dataset,
|
167 |
batch_size=data_cfg.batch_size,
|
168 |
+
sampler=sampler,
|
169 |
+
shuffle=data_cfg.shuffle and not use_ddp,
|
170 |
num_workers=data_cfg.cpu_num,
|
171 |
pin_memory=data_cfg.pin_memory,
|
172 |
collate_fn=self.collate_fn,
|
|
|
200 |
return batch_images, batch_targets
|
201 |
|
202 |
|
203 |
+
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
    """Build the data loader that matches the requested task.

    Args:
        data_cfg: Data configuration passed on to the loader.
        dataset_cfg: Dataset configuration; its ``auto_download`` flag decides
            whether ``prepare_dataset`` is run first.
        task: Task name. ``"inference"`` selects the stream loader.
        use_ddp: Forwarded to ``YoloDataLoader``; not used for inference.

    Returns:
        A ``StreamDataLoader`` for inference, otherwise a ``YoloDataLoader``.
    """
    streaming = task == "inference"
    if not streaming:
        # Dataset preparation only applies to the dataset-backed loader.
        if dataset_cfg.auto_download:
            prepare_dataset(dataset_cfg, task)
        loader = YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp)
    else:
        loader = StreamDataLoader(data_cfg)
    return loader
|
211 |
|
212 |
|
213 |
class StreamDataLoader:
|
yolo/tools/solver.py
CHANGED
@@ -7,6 +7,8 @@ from torch import Tensor
|
|
7 |
|
8 |
# TODO: We may not be able to use CUDA?
|
9 |
from torch.cuda.amp import GradScaler, autocast
|
|
|
|
|
10 |
|
11 |
from yolo.config.config import Config, TrainConfig, ValidationConfig
|
12 |
from yolo.model.yolo import YOLO
|
@@ -25,7 +27,8 @@ from yolo.utils.model_utils import (
|
|
25 |
class ModelTrainer:
|
26 |
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
27 |
train_cfg: TrainConfig = cfg.task
|
28 |
-
self.model = model
|
|
|
29 |
self.vec2box = vec2box
|
30 |
self.device = device
|
31 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
@@ -86,13 +89,15 @@ class ModelTrainer:
|
|
86 |
self.ema.restore()
|
87 |
torch.save(checkpoint, filename)
|
88 |
|
89 |
-
def solve(self, dataloader):
|
90 |
logger.info("π Start Training!")
|
91 |
num_epochs = self.num_epochs
|
92 |
|
93 |
with self.progress.progress:
|
94 |
self.progress.start_train(num_epochs)
|
95 |
for epoch in range(num_epochs):
|
|
|
|
|
96 |
|
97 |
self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
|
98 |
epoch_loss = self.train_one_epoch(dataloader)
|
|
|
7 |
|
8 |
# TODO: We may not be able to use CUDA?
|
9 |
from torch.cuda.amp import GradScaler, autocast
|
10 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
|
13 |
from yolo.config.config import Config, TrainConfig, ValidationConfig
|
14 |
from yolo.model.yolo import YOLO
|
|
|
27 |
class ModelTrainer:
|
28 |
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
29 |
train_cfg: TrainConfig = cfg.task
|
30 |
+
self.model = model if not use_ddp else DDP(model, device_ids=[device])
|
31 |
+
self.use_ddp = use_ddp
|
32 |
self.vec2box = vec2box
|
33 |
self.device = device
|
34 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
|
|
89 |
self.ema.restore()
|
90 |
torch.save(checkpoint, filename)
|
91 |
|
92 |
+
def solve(self, dataloader: DataLoader):
|
93 |
logger.info("π Start Training!")
|
94 |
num_epochs = self.num_epochs
|
95 |
|
96 |
with self.progress.progress:
|
97 |
self.progress.start_train(num_epochs)
|
98 |
for epoch in range(num_epochs):
|
99 |
+
if self.use_ddp:
|
100 |
+
dataloader.sampler.set_epoch(epoch)
|
101 |
|
102 |
self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
|
103 |
epoch_loss = self.train_one_epoch(dataloader)
|
yolo/utils/model_utils.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
-
|
|
|
2 |
|
3 |
import torch
|
4 |
import torch.distributed as dist
|
|
|
5 |
from omegaconf import ListConfig
|
6 |
from torch import nn
|
7 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
@@ -73,29 +75,21 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
|
|
73 |
return schedule
|
74 |
|
75 |
|
76 |
-
def
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
dist.init_process_group(
|
95 |
-
backend="gloo" if torch.cuda.is_available() else "gloo", rank=dist.get_rank(), world_size=world_size
|
96 |
-
)
|
97 |
-
print(f"Initialized process group; rank: {dist.get_rank()}, size: {world_size}")
|
98 |
-
|
99 |
-
model = model.cuda(device)
|
100 |
-
model = DDP(model, device_ids=[device])
|
101 |
-
return device, model.to(device)
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Tuple, Type, Union
|
3 |
|
4 |
import torch
|
5 |
import torch.distributed as dist
|
6 |
+
from loguru import logger
|
7 |
from omegaconf import ListConfig
|
8 |
from torch import nn
|
9 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
75 |
return schedule
|
76 |
|
77 |
|
78 |
+
def initialize_distributed() -> int:
    """Initialize the default process group and select this process's GPU.

    Reads the ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` environment
    variables, falling back to ``0``, ``0`` and ``1`` when they are unset
    (presumably they are provided by the distributed launcher).

    Returns:
        The local rank, which is also the CUDA device index selected here.
    """
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    # Select this process's GPU before the NCCL process group is created.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    logger.info(f"Initialized process group; rank: {rank}, size: {world_size}")
    return local_rank
|
87 |
+
|
88 |
+
|
89 |
+
def get_device(device_spec: Union[str, int, List[int]]) -> Tuple[torch.device, bool]:
    """Resolve the configured device and report whether DDP mode is on.

    Args:
        device_spec: Either a single device (a string or an int accepted by
            ``torch.device``) or a list of devices. A ``list`` or omegaconf
            ``ListConfig`` switches on distributed mode.

    Returns:
        A ``(device, use_ddp)`` tuple. ``use_ddp`` is ``True`` only when a list
        was given; in that case the process group is initialized and
        ``device`` is built from the local rank instead of the list entries.
    """
    use_ddp = isinstance(device_spec, (list, ListConfig))
    if use_ddp:
        # The listed devices are not used individually: each process takes
        # the device index returned by initialize_distributed().
        device_spec = initialize_distributed()
    return torch.device(device_spec), use_ddp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|