File size: 2,870 Bytes
3f9c425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# @package _global_
defaults:
  # Trainer preset; choose one from 'configs/trainer/'.
  - override /trainer: default
  - override /model: null
  - override /datamodule: openwebtext
  # apex FusedAdam makes the optimizer step slightly faster.
  # GPT2-small: time per global step (i.e. batch size 512) on 8 A100s drops
  # from 376ms to 368ms.
  # GPT2-medium: time per global step drops from 997ms to 972ms.
  - override /optimizer: adamw-apex
  - override /scheduler: linear-warmup
  - override /callbacks:
      - default
      - norm-monitor
  - override /metrics:
      - perplexity
      - num-tokens
  - override /logger: wandb

# All parameters below are merged with the parameters from the default
# configurations selected above, so you only need to override the ones you specify.

# Task to build; `_target_` holds the dotted import path of the class
# (presumably instantiated through Hydra -- confirm in the training script).
task:
  _target_: src.tasks.seq.SequenceLMModel

# Seed value for the run (presumably applied by the training entry point).
seed: 1111

# Trainer settings merged over the selected trainer preset.
trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  # train.global_batch_size divided (via the `div_up` resolver, presumably a
  # round-up division -- confirm) by devices * per-GPU batch_size * num_nodes.
  # With the values in this file that is 512 / (8 * 16 * 1).
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 400000
  # 1000 * accumulate_grad_batches; the leading `.` makes the reference
  # relative to this `trainer` node. Presumably the interval is counted in
  # micro-batches, i.e. one validation pass per 1000 optimizer steps -- confirm.
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null  # We don't care about epoch boundary
  precision: 16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  # Per-GPU batch size.
  batch_size: 16
  # Evaluation reuses the training batch size (fused dense only supports a
  # batch size of at most 64k).
  batch_size_eval: ${.batch_size}
  max_length: 1024
  fault_tolerant: true
  # True when trainer.devices is greater than 1.
  ddp: ${eval:"${trainer.devices} > 1"}
train:
  # Total memory of GPU 0 as printed by `nvidia-smi` (memory.total, no units),
  # divided by 1000 and rounded. NOTE(review): resolving this value runs
  # `nvidia-smi` through a shell, so it needs that binary and a visible GPU 0
  # on the machine where the config is resolved.
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  global_batch_size: 512
  optimizer:
    # Written with a decimal point so that YAML 1.1 loaders (e.g. PyYAML) also
    # read it as a float; a bare `6e-4` is loaded as a string by those loaders.
    lr: 6.0e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  scheduler:
    # Warmup length is 1% of trainer.max_steps.
    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
    num_training_steps: ${trainer.max_steps}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: true  # to save memory

eval:
  # One training epoch takes too long, so we want to see metrics per
  # training step.
  log_on_step: true

callbacks:
  # Checkpoints ranked by val/loss (lower is better): top 3 plus the last one.
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: true
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: false
  # Progress checkpoints taken every 50000 training steps.
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    fault_tolerant: true
    every_n_train_steps: 50000
    save_last: false
    # Save all the checkpoints.
    save_top_k: -1
    # Reuses the directory configured for `model_checkpoint`.
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: false
  # Set to null (presumably disables the early-stopping callback inherited
  # from the default callbacks config -- confirm).
  early_stopping: null