Implement DDP `static_graph=True` (#6940)
Browse files* Implement DDP `static_graph=True`
Experimental implementation of new PyTorch 1.11.0 DDP feature.
* Add 1.11.0 check
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
train.py
CHANGED
@@ -47,9 +47,9 @@ from utils.callbacks import Callbacks
|
|
47 |
from utils.datasets import create_dataloader
|
48 |
from utils.downloads import attempt_download
|
49 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
-
check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
|
51 |
-
intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
|
52 |
-
one_cycle, print_args, print_mutation, strip_optimizer)
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
@@ -269,7 +269,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
269 |
|
270 |
# DDP mode
|
271 |
if cuda and RANK != -1:
|
272 |
-
|
|
|
|
|
|
|
273 |
|
274 |
# Model attributes
|
275 |
nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
|
|
|
47 |
from utils.datasets import create_dataloader
|
48 |
from utils.downloads import attempt_download
|
49 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
+
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
|
51 |
+
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
|
52 |
+
methods, one_cycle, print_args, print_mutation, strip_optimizer)
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
|
|
269 |
|
270 |
# DDP mode
|
271 |
if cuda and RANK != -1:
|
272 |
+
if check_version(torch.__version__, '1.11.0'):
|
273 |
+
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
|
274 |
+
else:
|
275 |
+
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
276 |
|
277 |
# Model attributes
|
278 |
nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
|