boris committed
Commit 728a3c3 · unverified · 1 Parent(s): 49be45e

feat: better multi-node support (#158)

* reproducible data loader
* custom sharding
* model parallel across multiple nodes
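
The key idea behind the multi-node support is that when the model is sharded over more devices than a single host owns (`mp_devices > jax.local_device_count()`), the hosts holding the same model replica must feed identical data. A minimal sketch of that arithmetic, using the names from the train.py diff below with made-up device counts:

```python
# Illustrative numbers only: 8 devices per host, model sharded over 2 hosts.
local_device_count = 8   # jax.local_device_count()
mp_devices = 16          # training_args.mp_devices
batch_size_per_node = 4  # assumed value; computed elsewhere in train.py

# hosts belonging to the same model replica form one "node group"
# and must load the same (replicated) batch
node_groups = max(1, mp_devices // local_device_count)  # -> 2
loader_bs = batch_size_per_node * node_groups            # -> 8
print(node_groups, loader_bs)
```

This is also why the commit asserts that `seed_dataset` is provided whenever `mp_devices > jax.local_device_count()`: every host in a node group needs to make the same shuffling and caption-blanking decisions.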

src/dalle_mini/data.py CHANGED
@@ -43,6 +43,8 @@ class Dataset:
         if self.seed_dataset is None:
             # create a random seed
             self.seed_dataset = random.randint(0, 2**32 - 1)
+        # set numpy rng
+        self.np_rng = np.random.default_rng(self.seed_dataset)
         self.multi_hosts = jax.process_count() > 1
         # feed blank captions only in streaming mode for now
         # otherwise dataset could be cached with same blanked captions
@@ -173,6 +175,7 @@ class Dataset:
                 blank_caption_function,
                 text_column=self.text_column,
                 blank_caption_prob=self.blank_caption_prob,
+                rng=self.np_rng,
             )
             if hasattr(self, "train_dataset"):
                 self.train_dataset = (
@@ -180,7 +183,9 @@ class Dataset:
                     if self.streaming
                     else self.train_dataset.map(
                         partial_blank_caption_function,
-                        num_proc=self.preprocessing_num_workers,
+                        num_proc=None
+                        if self.seed_dataset
+                        else self.preprocessing_num_workers,
                         load_from_cache_file=False,
                         desc="Blanking some captions",
                     )
@@ -316,8 +321,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
     return shifted_input_ids
 
 
-def blank_caption_function(example, text_column, blank_caption_prob):
-    if blank_caption_prob and np.random.rand() < blank_caption_prob:
+def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
+    if (
+        blank_caption_prob
+        and (rng.random() if rng is not None else np.random.random())
+        < blank_caption_prob
+    ):
         example[text_column] = ""
     return example
 
 
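For illustration, a self-contained sketch (not code from the repo) of why the dedicated, seeded generator makes caption blanking reproducible across hosts: two processes that build `np.random.default_rng` from the same `seed_dataset` draw the same sequence of floats and therefore blank the same captions. Presumably this is also why `num_proc` falls back to `None` when a dataset seed is in use: a single shared generator only gives deterministic results if it is consumed in one process, in order.

```python
import numpy as np

def blank_caption(example, text_column, blank_caption_prob, rng=None):
    # same logic as blank_caption_function above
    if blank_caption_prob and (
        rng.random() if rng is not None else np.random.random()
    ) < blank_caption_prob:
        example[text_column] = ""
    return example

seed_dataset = 1234  # hypothetical seed shared by every host
examples = [{"caption": f"a photo of item {i}"} for i in range(6)]

# two "hosts" seeded identically make identical blanking decisions
rng_a = np.random.default_rng(seed_dataset)
rng_b = np.random.default_rng(seed_dataset)
host_a = [blank_caption(dict(e), "caption", 0.5, rng_a) for e in examples]
host_b = [blank_caption(dict(e), "caption", 0.5, rng_b) for e in examples]
assert host_a == host_b
```
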
tools/train/config/mega/config.json CHANGED
@@ -1,30 +1,49 @@
 {
   "activation_dropout": 0.0,
-  "activation_function": "gelu",
+  "activation_function": "swish",
   "attention_dropout": 0.0,
   "bos_token_id": 16385,
   "d_model": 2048,
   "decoder_attention_heads": 32,
-  "decoder_ffn_dim": 8192,
+  "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers": 24,
+  "decoder_layers": 25,
   "decoder_start_token_id": 16384,
+  "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
-  "encoder_ffn_dim": 8192,
+  "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers": 24,
-  "encoder_vocab_size": 50264,
+  "encoder_layers": 25,
+  "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
+  "force_ln_scale": false,
+  "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size": 16391,
+  "image_vocab_size": 16415,
   "init_std": 0.01,
   "is_encoder_decoder": true,
+  "ln_positions": "normformer",
+  "ln_type": "layernorm",
+  "max_length": 257,
   "max_text_length": 64,
+  "min_length": 257,
   "model_type": "dallebart",
   "normalize_text": true,
   "pad_token_id": 16385,
   "scale_embedding": false,
+  "sinkhorn_iters": 1,
+  "tau_init": 0.05,
   "tie_word_embeddings": false,
-  "use_cache": true
+  "use_absolute_position_embeddings": true,
+  "use_alibi": false,
+  "use_bias": false,
+  "use_cache": true,
+  "use_cosine_attention": false,
+  "use_deepnet_scaling": false,
+  "use_final_ln_decoder": true,
+  "use_final_ln_encoder": true,
+  "use_glu": true,
+  "use_head_scale": false,
+  "use_swin_position_embeddings": false
 }
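
A small consistency check, not part of the commit, that helps read the new generation settings: the decoder emits exactly `image_length` image tokens after the start token, which is presumably why both `min_length` and `max_length` are pinned to 257 (`image_length + 1`).

```python
import json

# path relative to the repo root, as in this diff
with open("tools/train/config/mega/config.json") as f:
    cfg = json.load(f)

# generation length = image_length image tokens + 1 start token
assert cfg["min_length"] == cfg["max_length"] == cfg["image_length"] + 1

# a few of the values this commit changes
print(cfg["activation_function"], cfg["encoder_layers"], cfg["decoder_ffn_dim"])
```
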
tools/train/config/mini/config.json CHANGED
@@ -16,7 +16,7 @@
   "eos_token_id": 16385,
   "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size": 16384,
+  "image_vocab_size": 16391,
   "init_std": 0.02,
   "is_encoder_decoder": true,
   "max_text_length": 64,
tools/train/train.py CHANGED
@@ -368,6 +368,12 @@ class TrainingArguments:
             "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )
+    shard_shampoo_across: str = field(
+        default="dp",
+        metadata={
+            "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
+        },
+    )
 
     num_train_epochs: int = field(
         default=3, metadata={"help": "Total number of training epochs to perform."}
@@ -450,6 +456,11 @@ class TrainingArguments:
         metadata={"help": "Verify that TPU is not in use."},
     )
 
+    use_vmap_trick: bool = field(
+        default=True,
+        metadata={"help": "Whether to use the 'vmap trick': avoids a crash when mp_devices > 1 and can improve performance."},
+    )
+
     mp_devices: Optional[int] = field(
         default=1,
         metadata={
@@ -500,6 +511,11 @@
                     f"Output directory ({self.output_dir}) already exists and is not empty."
                     "Use --overwrite_output_dir to overcome."
                 )
+        assert self.shard_shampoo_across in [
+            "dp",
+            "mp",
+            "2d",
+        ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
        assert (
            self.mp_devices > 0
        ), f"Number of devices for model parallelism must be > 0"
@@ -530,6 +546,12 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    # check arguments
+    if training_args.mp_devices > jax.local_device_count():
+        assert (
+            data_args.seed_dataset is not None
+        ), "Seed dataset must be provided when model is split over multiple hosts"
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -748,8 +770,20 @@ def main():
         graft_type=graft_type,
         nesterov=False,
         exponent_override=0,
-        statistics_partition_spec=PartitionSpec(None, "dp", None),
-        preconditioner_partition_spec=PartitionSpec("dp", None, None),
+        statistics_partition_spec=PartitionSpec(
+            None, training_args.shard_shampoo_across, None
+        )
+        if training_args.shard_shampoo_across != "2d"
+        else PartitionSpec(None, "dp", "mp"),
+        preconditioner_partition_spec=PartitionSpec(
+            training_args.shard_shampoo_across, None, None
+        )
+        if training_args.shard_shampoo_across != "2d"
+        else PartitionSpec(
+            "mp" if training_args.mp_devices > training_args.dp_devices else "dp",
+            None,
+            None,
+        ),
         num_devices_for_pjit=training_args.dp_devices,
         shard_optimizer_states=True,
         inverse_failure_threshold=0.1,
 
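The `shard_shampoo_across` logic above is fairly dense inline; the following standalone helper (a sketch, not code from the repo) restates the same selection so it is easier to see which mesh axis each piece of Shampoo state lands on:

```python
from jax.sharding import PartitionSpec  # older JAX versions: from jax.experimental import PartitionSpec

def shampoo_partition_specs(shard_shampoo_across, mp_devices, dp_devices):
    """Restates the statistics/preconditioner spec choice made in train.py."""
    if shard_shampoo_across != "2d":
        # shard along a single mesh axis: "dp" or "mp"
        stats = PartitionSpec(None, shard_shampoo_across, None)
        precond = PartitionSpec(shard_shampoo_across, None, None)
    else:
        # "2d": statistics spread over both axes,
        # preconditioners over whichever axis has more devices
        stats = PartitionSpec(None, "dp", "mp")
        precond = PartitionSpec(
            "mp" if mp_devices > dp_devices else "dp", None, None
        )
    return stats, precond

# the default ("dp") reproduces the previously hard-coded data-parallel sharding
print(shampoo_partition_specs("dp", mp_devices=8, dp_devices=4))
```
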
@@ -917,7 +951,7 @@ def main():
 
     # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
     # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
-    use_vmap_trick = True
+    use_vmap_trick = training_args.use_vmap_trick
 
     # make grad_param_spec for vmap
     if use_vmap_trick:
@@ -1145,7 +1179,8 @@ def main():
             self.log_time("train_per_log", delta_time, offset=False)
 
         def log_time(self, key, duration, offset=True):
-            wandb.log({f"time/{key}": duration, **self.state_dict})
+            if jax.process_index() == 0:
+                wandb.log({f"time/{key}": duration, **self.state_dict})
             if offset:
                 self.offset_time += duration
 
@@ -1191,7 +1226,11 @@ def main():
         # ======================== Evaluating ==============================
         if training_args.do_eval:
             start_eval_time = time.perf_counter()
-            eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
+            eval_loader = dataset.dataloader(
+                "eval",
+                eval_batch_size_per_step
+                * max(1, training_args.mp_devices // jax.local_device_count()),
+            )
             eval_steps = (
                 len_eval_dataset // eval_batch_size_per_step
                 if len_eval_dataset is not None
@@ -1353,10 +1392,12 @@ def main():
         metrics_logger.update_state_metrics(local_state)
         metrics_logger.log({})
 
-        # Generate an epoch by shuffling sampling indices from the train dataset
+        # load data - may be replicated on multiple nodes
+        node_groups = max(1, training_args.mp_devices // jax.local_device_count())
+        loader_bs = batch_size_per_node * node_groups
         train_loader = dataset.dataloader(
             "train",
-            batch_size_per_node,
+            loader_bs,
             epoch,
         )
         # train
@@ -1373,12 +1414,12 @@ def main():
 
             # set correct shape to batch
             # - add grad_step dim if gradient_accumulation_steps > 1
-            # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
             bs_shape = (
-                (batch_size_per_node_per_grad_step,)
+                (batch_size_per_node_per_grad_step * node_groups,)
                 if not use_vmap_trick
                 else (
                     jax.local_device_count()
+                    * node_groups
                     // training_args.mp_devices,  # local dp devices
                     training_args.per_device_train_batch_size,
                 )
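
To make the new batch-shape arithmetic concrete, here is a worked example with made-up device counts (it mirrors the expressions in the hunk above; none of the numbers come from the commit): with 8 devices per host and the model sharded over 16 devices, two hosts form one node group, each host ends up with a single local data-parallel slice, and the loader batch is scaled accordingly.

```python
# Hypothetical layout: 8 devices per host, model sharded across 16 devices (2 hosts).
local_device_count = 8                 # jax.local_device_count()
mp_devices = 16                        # training_args.mp_devices
per_device_train_batch_size = 2        # training_args.per_device_train_batch_size
batch_size_per_node_per_grad_step = 4  # assumed; computed earlier in train.py
use_vmap_trick = True

node_groups = max(1, mp_devices // local_device_count)  # -> 2 hosts per model replica

bs_shape = (
    (batch_size_per_node_per_grad_step * node_groups,)
    if not use_vmap_trick
    else (
        local_device_count * node_groups // mp_devices,  # local dp devices -> 1
        per_device_train_batch_size,                     # -> 2
    )
)
print(node_groups, bs_shape)  # 2 (1, 2)
```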