henry000 commited on
Commit
7d7e199
Β·
1 Parent(s): bce644c

🚧 [WIP] DDP for model training

Browse files
yolo/lazy.py CHANGED
@@ -14,19 +14,18 @@ from yolo.tools.solver import ModelTester, ModelTrainer
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
  from yolo.utils.logging_utils import ProgressLogger
 
17
 
18
 
19
  @hydra.main(config_path="config", config_name="config", version_base=None)
20
  def main(cfg: Config):
21
  progress = ProgressLogger(cfg, exp_name=cfg.name)
22
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
23
- device = torch.device(cfg.device)
24
  if getattr(cfg.task, "fast_inference", False):
25
- model = FastModelLoader(cfg, device).load_model()
26
- device = torch.device(cfg.device)
27
  else:
28
- model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
29
-
30
  vec2box = Vec2Box(model, cfg.image_size, device)
31
 
32
  if cfg.task.task == "train":
 
14
  from yolo.utils.bounding_box_utils import Vec2Box
15
  from yolo.utils.deploy_utils import FastModelLoader
16
  from yolo.utils.logging_utils import ProgressLogger
17
+ from yolo.utils.model_utils import send_to_device
18
 
19
 
20
  @hydra.main(config_path="config", config_name="config", version_base=None)
21
  def main(cfg: Config):
22
  progress = ProgressLogger(cfg, exp_name=cfg.name)
23
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
 
24
  if getattr(cfg.task, "fast_inference", False):
25
+ model = FastModelLoader(cfg).load_model()
 
26
  else:
27
+ model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
28
+ device, model = send_to_device(model, cfg.device)
29
  vec2box = Vec2Box(model, cfg.image_size, device)
30
 
31
  if cfg.task.task == "train":
yolo/model/yolo.py CHANGED
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
4
  import torch
5
  from loguru import logger
6
  from omegaconf import ListConfig, OmegaConf
7
- from torch import device, nn
8
 
9
  from yolo.config.config import Config, ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
@@ -117,7 +117,7 @@ class YOLO(nn.Module):
117
  raise ValueError(f"Unsupported layer type: {layer_type}")
118
 
119
 
120
- def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: device, class_num: int = 80) -> YOLO:
121
  """Constructs and returns a model from a Dictionary configuration file.
122
 
123
  Args:
@@ -134,9 +134,10 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: dev
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
  if os.path.exists(weight_path):
137
- model.model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
 
138
  logger.info("βœ… Success load model weight")
139
 
140
  log_model_structure(model.model)
141
  draw_model(model=model)
142
- return model.to(device)
 
4
  import torch
5
  from loguru import logger
6
  from omegaconf import ListConfig, OmegaConf
7
+ from torch import nn
8
 
9
  from yolo.config.config import Config, ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
 
117
  raise ValueError(f"Unsupported layer type: {layer_type}")
118
 
119
 
120
+ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num: int = 80) -> YOLO:
121
  """Constructs and returns a model from a Dictionary configuration file.
122
 
123
  Args:
 
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
  if os.path.exists(weight_path):
137
+ # TODO: fix map_location
138
+ model.model.load_state_dict(torch.load(weight_path), strict=False)
139
  logger.info("βœ… Success load model weight")
140
 
141
  log_model_structure(model.model)
142
  draw_model(model=model)
143
+ return model
yolo/utils/deploy_utils.py CHANGED
@@ -9,9 +9,8 @@ from yolo.model.yolo import create_model
9
 
10
 
11
  class FastModelLoader:
12
- def __init__(self, cfg: Config, device):
13
  self.cfg = cfg
14
- self.device = device
15
  self.compiler = cfg.task.fast_inference
16
  self._validate_compiler()
17
  self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
 
9
 
10
 
11
  class FastModelLoader:
12
+ def __init__(self, cfg: Config):
13
  self.cfg = cfg
 
14
  self.compiler = cfg.task.fast_inference
15
  self._validate_compiler()
16
  self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
yolo/utils/model_utils.py CHANGED
@@ -1,6 +1,10 @@
1
- from typing import Any, Dict, Type
2
 
3
  import torch
 
 
 
 
4
  from torch.optim import Optimizer
5
  from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
6
 
@@ -67,3 +71,31 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
67
  warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
68
  schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
69
  return schedule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Type, Union
2
 
3
  import torch
4
+ import torch.distributed as dist
5
+ from omegaconf import ListConfig
6
+ from torch import nn
7
+ from torch.nn.parallel import DistributedDataParallel as DDP
8
  from torch.optim import Optimizer
9
  from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
10
 
 
71
  warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
72
  schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
73
  return schedule
74
+
75
+
76
+ def get_device():
77
+ if torch.cuda.is_available():
78
+ return torch.device("cuda")
79
+ elif torch.backends.mps.is_available():
80
+ return torch.device("mps")
81
+ else:
82
+ return torch.device("cpu")
83
+
84
+
85
+ def send_to_device(model: nn.Module, device: Union[str, int, List[int]]):
86
+ if not isinstance(device, (List, ListConfig)):
87
+ device = torch.device(device)
88
+ print("runing man")
89
+ return device, model.to(device)
90
+
91
+ device = torch.device("cuda")
92
+ world_size = dist.get_world_size()
93
+ print("runing man")
94
+ dist.init_process_group(
95
+ backend="gloo" if torch.cuda.is_available() else "gloo", rank=dist.get_rank(), world_size=world_size
96
+ )
97
+ print(f"Initialized process group; rank: {dist.get_rank()}, size: {world_size}")
98
+
99
+ model = model.cuda(device)
100
+ model = DDP(model, device_ids=[device])
101
+ return device, model.to(device)