glenn-jocher pre-commit-ci[bot] commited on
Commit
d95a728
·
unverified ·
1 Parent(s): f3fecf9

Implement DDP `static_graph=True` (#6940)

Browse files

* Implement DDP `static_graph=True`

Experimental implementation of new PyTorch 1.11.0 DDP feature.

* Add 1.11.0 check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (1) hide show
  1. train.py +7 -4
train.py CHANGED
@@ -47,9 +47,9 @@ from utils.callbacks import Callbacks
47
  from utils.datasets import create_dataloader
48
  from utils.downloads import attempt_download
49
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50
- check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
51
- intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods,
52
- one_cycle, print_args, print_mutation, strip_optimizer)
53
  from utils.loggers import Loggers
54
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
55
  from utils.loss import ComputeLoss
@@ -269,7 +269,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
269
 
270
  # DDP mode
271
  if cuda and RANK != -1:
272
- model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
 
 
 
273
 
274
  # Model attributes
275
  nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
 
47
  from utils.datasets import create_dataloader
48
  from utils.downloads import attempt_download
49
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50
+ check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
51
+ init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
52
+ methods, one_cycle, print_args, print_mutation, strip_optimizer)
53
  from utils.loggers import Loggers
54
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
55
  from utils.loss import ComputeLoss
 
269
 
270
  # DDP mode
271
  if cuda and RANK != -1:
272
+ if check_version(torch.__version__, '1.11.0'):
273
+ model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
274
+ else:
275
+ model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
276
 
277
  # Model attributes
278
  nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)