🚧 [WIP] DDP for model training
Browse files
- yolo/lazy.py +4 -5
- yolo/model/yolo.py +5 -4
- yolo/utils/deploy_utils.py +1 -2
- yolo/utils/model_utils.py +33 -1
yolo/lazy.py
CHANGED
@@ -14,19 +14,18 @@ from yolo.tools.solver import ModelTester, ModelTrainer
|
|
14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
15 |
from yolo.utils.deploy_utils import FastModelLoader
|
16 |
from yolo.utils.logging_utils import ProgressLogger
|
|
|
17 |
|
18 |
|
19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
20 |
def main(cfg: Config):
|
21 |
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
23 |
-
device = torch.device(cfg.device)
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
25 |
-
model = FastModelLoader(cfg
|
26 |
-
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
-
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight
|
29 |
-
|
30 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
31 |
|
32 |
if cfg.task.task == "train":
|
|
|
14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
15 |
from yolo.utils.deploy_utils import FastModelLoader
|
16 |
from yolo.utils.logging_utils import ProgressLogger
|
17 |
+
from yolo.utils.model_utils import send_to_device
|
18 |
|
19 |
|
20 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
21 |
def main(cfg: Config):
|
22 |
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
23 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
|
|
24 |
if getattr(cfg.task, "fast_inference", False):
|
25 |
+
model = FastModelLoader(cfg).load_model()
|
|
|
26 |
else:
|
27 |
+
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
|
28 |
+
device, model = send_to_device(model, cfg.device)
|
29 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
30 |
|
31 |
if cfg.task.task == "train":
|
yolo/model/yolo.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
|
|
4 |
import torch
|
5 |
from loguru import logger
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
-
from torch import
|
8 |
|
9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
@@ -117,7 +117,7 @@ class YOLO(nn.Module):
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
-
def create_model(model_cfg: ModelConfig, weight_path: Optional[str],
|
121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
122 |
|
123 |
Args:
|
@@ -134,9 +134,10 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: dev
|
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
if os.path.exists(weight_path):
|
137 |
-
|
|
|
138 |
logger.info("β
Success load model weight")
|
139 |
|
140 |
log_model_structure(model.model)
|
141 |
draw_model(model=model)
|
142 |
-
return model
|
|
|
4 |
import torch
|
5 |
from loguru import logger
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
+
from torch import nn
|
8 |
|
9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
+
def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num: int = 80) -> YOLO:
|
121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
122 |
|
123 |
Args:
|
|
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
if os.path.exists(weight_path):
|
137 |
+
# TODO: fix map_location
|
138 |
+
model.model.load_state_dict(torch.load(weight_path), strict=False)
|
139 |
logger.info("β
Success load model weight")
|
140 |
|
141 |
log_model_structure(model.model)
|
142 |
draw_model(model=model)
|
143 |
+
return model
|
yolo/utils/deploy_utils.py
CHANGED
@@ -9,9 +9,8 @@ from yolo.model.yolo import create_model
|
|
9 |
|
10 |
|
11 |
class FastModelLoader:
|
12 |
-
def __init__(self, cfg: Config
|
13 |
self.cfg = cfg
|
14 |
-
self.device = device
|
15 |
self.compiler = cfg.task.fast_inference
|
16 |
self._validate_compiler()
|
17 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
|
|
9 |
|
10 |
|
11 |
class FastModelLoader:
|
12 |
+
def __init__(self, cfg: Config):
|
13 |
self.cfg = cfg
|
|
|
14 |
self.compiler = cfg.task.fast_inference
|
15 |
self._validate_compiler()
|
16 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
yolo/utils/model_utils.py
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
-
from typing import Any, Dict, Type
|
2 |
|
3 |
import torch
|
|
|
|
|
|
|
|
|
4 |
from torch.optim import Optimizer
|
5 |
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
6 |
|
@@ -67,3 +71,31 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
|
|
67 |
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
68 |
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
69 |
return schedule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Type, Union
|
2 |
|
3 |
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
from omegaconf import ListConfig
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
8 |
from torch.optim import Optimizer
|
9 |
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
10 |
|
|
|
71 |
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
72 |
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
73 |
return schedule
|
74 |
+
|
75 |
+
|
76 |
+
def get_device():
|
77 |
+
if torch.cuda.is_available():
|
78 |
+
return torch.device("cuda")
|
79 |
+
elif torch.backends.mps.is_available():
|
80 |
+
return torch.device("mps")
|
81 |
+
else:
|
82 |
+
return torch.device("cpu")
|
83 |
+
|
84 |
+
|
85 |
+
def send_to_device(model: nn.Module, device: Union[str, int, List[int]]):
|
86 |
+
if not isinstance(device, (List, ListConfig)):
|
87 |
+
device = torch.device(device)
|
88 |
+
print("runing man")
|
89 |
+
return device, model.to(device)
|
90 |
+
|
91 |
+
device = torch.device("cuda")
|
92 |
+
world_size = dist.get_world_size()
|
93 |
+
print("runing man")
|
94 |
+
dist.init_process_group(
|
95 |
+
backend="gloo" if torch.cuda.is_available() else "gloo", rank=dist.get_rank(), world_size=world_size
|
96 |
+
)
|
97 |
+
print(f"Initialized process group; rank: {dist.get_rank()}, size: {world_size}")
|
98 |
+
|
99 |
+
model = model.cuda(device)
|
100 |
+
model = DDP(model, device_ids=[device])
|
101 |
+
return device, model.to(device)
|