henry000 commited on
Commit
1fe2937
Β·
1 Parent(s): 1e3931d

πŸš€ [New] DDP mode for training model

Browse files
yolo/lazy.py CHANGED
@@ -2,7 +2,6 @@ import sys
2
  from pathlib import Path
3
 
4
  import hydra
5
- import torch
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
@@ -14,22 +13,24 @@ from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
  from yolo.utils.logging_utils import ProgressLogger
17
- from yolo.utils.model_utils import send_to_device
18
 
19
 
20
  @hydra.main(config_path="config", config_name="config", version_base=None)
21
  def main(cfg: Config):
22
  progress = ProgressLogger(cfg, exp_name=cfg.name)
23
- dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
 
24
  if getattr(cfg.task, "fast_inference", False):
25
  model = FastModelLoader(cfg).load_model()
26
  else:
27
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
28
- device, model = send_to_device(model, cfg.device)
 
29
  vec2box = Vec2Box(model, cfg.image_size, device)
30
 
31
  if cfg.task.task == "train":
32
- trainer = ModelTrainer(cfg, model, vec2box, progress, device)
33
  trainer.solve(dataloader)
34
 
35
  if cfg.task.task == "inference":
 
2
  from pathlib import Path
3
 
4
  import hydra
 
5
 
6
  project_root = Path(__file__).resolve().parent.parent
7
  sys.path.append(str(project_root))
 
13
  from yolo.utils.bounding_box_utils import Vec2Box
14
  from yolo.utils.deploy_utils import FastModelLoader
15
  from yolo.utils.logging_utils import ProgressLogger
16
+ from yolo.utils.model_utils import get_device
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
  progress = ProgressLogger(cfg, exp_name=cfg.name)
22
+ device, use_ddp = get_device(cfg.device)
23
+ dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
24
  if getattr(cfg.task, "fast_inference", False):
25
  model = FastModelLoader(cfg).load_model()
26
  else:
27
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
28
+ model = model.to(device)
29
+
30
  vec2box = Vec2Box(model, cfg.image_size, device)
31
 
32
  if cfg.task.task == "train":
33
+ trainer = ModelTrainer(cfg, model, vec2box, progress, device, use_ddp)
34
  trainer.solve(dataloader)
35
 
36
  if cfg.task.task == "inference":
yolo/tools/data_loader.py CHANGED
@@ -12,7 +12,7 @@ from PIL import Image
12
  from rich.progress import track
13
  from torch import Tensor
14
  from torch.utils.data import DataLoader, Dataset
15
- from torchvision.transforms import functional as TF
16
 
17
  from yolo.config.config import DataConfig, DatasetConfig
18
  from yolo.tools.data_augmentation import (
@@ -157,14 +157,16 @@ class YoloDataset(Dataset):
157
 
158
 
159
  class YoloDataLoader(DataLoader):
160
- def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
161
  """Initializes the YoloDataLoader with hydra-config files."""
162
  dataset = YoloDataset(data_cfg, dataset_cfg, task)
 
163
  self.image_size = data_cfg.image_size[0]
164
  super().__init__(
165
  dataset,
166
  batch_size=data_cfg.batch_size,
167
- shuffle=data_cfg.shuffle,
 
168
  num_workers=data_cfg.cpu_num,
169
  pin_memory=data_cfg.pin_memory,
170
  collate_fn=self.collate_fn,
@@ -198,14 +200,14 @@ class YoloDataLoader(DataLoader):
198
  return batch_images, batch_targets
199
 
200
 
201
- def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
202
  if task == "inference":
203
  return StreamDataLoader(data_cfg)
204
 
205
  if dataset_cfg.auto_download:
206
  prepare_dataset(dataset_cfg, task)
207
 
208
- return YoloDataLoader(data_cfg, dataset_cfg, task)
209
 
210
 
211
  class StreamDataLoader:
 
12
  from rich.progress import track
13
  from torch import Tensor
14
  from torch.utils.data import DataLoader, Dataset
15
+ from torch.utils.data.distributed import DistributedSampler
16
 
17
  from yolo.config.config import DataConfig, DatasetConfig
18
  from yolo.tools.data_augmentation import (
 
157
 
158
 
159
  class YoloDataLoader(DataLoader):
160
+ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
161
  """Initializes the YoloDataLoader with hydra-config files."""
162
  dataset = YoloDataset(data_cfg, dataset_cfg, task)
163
+ sampler = DistributedSampler(dataset, shuffle=data_cfg.shuffle) if use_ddp else None
164
  self.image_size = data_cfg.image_size[0]
165
  super().__init__(
166
  dataset,
167
  batch_size=data_cfg.batch_size,
168
+ sampler=sampler,
169
+ shuffle=data_cfg.shuffle and not use_ddp,
170
  num_workers=data_cfg.cpu_num,
171
  pin_memory=data_cfg.pin_memory,
172
  collate_fn=self.collate_fn,
 
200
  return batch_images, batch_targets
201
 
202
 
203
+ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train", use_ddp: bool = False):
204
  if task == "inference":
205
  return StreamDataLoader(data_cfg)
206
 
207
  if dataset_cfg.auto_download:
208
  prepare_dataset(dataset_cfg, task)
209
 
210
+ return YoloDataLoader(data_cfg, dataset_cfg, task, use_ddp)
211
 
212
 
213
  class StreamDataLoader:
yolo/tools/solver.py CHANGED
@@ -7,6 +7,8 @@ from torch import Tensor
7
 
8
  # TODO: We may can't use CUDA?
9
  from torch.cuda.amp import GradScaler, autocast
 
 
10
 
11
  from yolo.config.config import Config, TrainConfig, ValidationConfig
12
  from yolo.model.yolo import YOLO
@@ -25,7 +27,8 @@ from yolo.utils.model_utils import (
25
  class ModelTrainer:
26
  def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
27
  train_cfg: TrainConfig = cfg.task
28
- self.model = model
 
29
  self.vec2box = vec2box
30
  self.device = device
31
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
@@ -86,13 +89,15 @@ class ModelTrainer:
86
  self.ema.restore()
87
  torch.save(checkpoint, filename)
88
 
89
- def solve(self, dataloader):
90
  logger.info("πŸš„ Start Training!")
91
  num_epochs = self.num_epochs
92
 
93
  with self.progress.progress:
94
  self.progress.start_train(num_epochs)
95
  for epoch in range(num_epochs):
 
 
96
 
97
  self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
98
  epoch_loss = self.train_one_epoch(dataloader)
 
7
 
8
  # TODO: We may can't use CUDA?
9
  from torch.cuda.amp import GradScaler, autocast
10
+ from torch.nn.parallel import DistributedDataParallel as DDP
11
+ from torch.utils.data import DataLoader
12
 
13
  from yolo.config.config import Config, TrainConfig, ValidationConfig
14
  from yolo.model.yolo import YOLO
 
27
  class ModelTrainer:
28
  def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
29
  train_cfg: TrainConfig = cfg.task
30
+ self.model = model if not use_ddp else DDP(model, device_ids=[device])
31
+ self.use_ddp = use_ddp
32
  self.vec2box = vec2box
33
  self.device = device
34
  self.optimizer = create_optimizer(model, train_cfg.optimizer)
 
89
  self.ema.restore()
90
  torch.save(checkpoint, filename)
91
 
92
+ def solve(self, dataloader: DataLoader):
93
  logger.info("πŸš„ Start Training!")
94
  num_epochs = self.num_epochs
95
 
96
  with self.progress.progress:
97
  self.progress.start_train(num_epochs)
98
  for epoch in range(num_epochs):
99
+ if self.use_ddp:
100
+ dataloader.sampler.set_epoch(epoch)
101
 
102
  self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
103
  epoch_loss = self.train_one_epoch(dataloader)
yolo/utils/model_utils.py CHANGED
@@ -1,7 +1,9 @@
1
- from typing import Any, Dict, List, Type, Union
 
2
 
3
  import torch
4
  import torch.distributed as dist
 
5
  from omegaconf import ListConfig
6
  from torch import nn
7
  from torch.nn.parallel import DistributedDataParallel as DDP
@@ -73,29 +75,21 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
73
  return schedule
74
 
75
 
76
- def get_device():
77
- if torch.cuda.is_available():
78
- return torch.device("cuda")
79
- elif torch.backends.mps.is_available():
80
- return torch.device("mps")
81
- else:
82
- return torch.device("cpu")
83
-
84
-
85
- def send_to_device(model: nn.Module, device: Union[str, int, List[int]]):
86
- if not isinstance(device, (List, ListConfig)):
87
- device = torch.device(device)
88
- print("runing man")
89
- return device, model.to(device)
90
-
91
- device = torch.device("cuda")
92
- world_size = dist.get_world_size()
93
- print("runing man")
94
- dist.init_process_group(
95
- backend="gloo" if torch.cuda.is_available() else "gloo", rank=dist.get_rank(), world_size=world_size
96
- )
97
- print(f"Initialized process group; rank: {dist.get_rank()}, size: {world_size}")
98
-
99
- model = model.cuda(device)
100
- model = DDP(model, device_ids=[device])
101
- return device, model.to(device)
 
1
+ import os
2
+ from typing import List, Type, Union
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from loguru import logger
7
  from omegaconf import ListConfig
8
  from torch import nn
9
  from torch.nn.parallel import DistributedDataParallel as DDP
 
75
  return schedule
76
 
77
 
78
+ def initialize_distributed() -> None:
79
+ rank = int(os.getenv("RANK", "0"))
80
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
81
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
82
+
83
+ torch.cuda.set_device(local_rank)
84
+ dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
85
+ logger.info(f"Initialized process group; rank: {rank}, size: {world_size}")
86
+ return local_rank
87
+
88
+
89
+ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device:
90
+ ddp_flag = False
91
+ if isinstance(device_spec, (list, ListConfig)):
92
+ ddp_flag = True
93
+ device_spec = initialize_distributed()
94
+ device = torch.device(device_spec)
95
+ return device, ddp_flag