defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

mode: sample_eval  # train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: membrane_esm_finetune  # dit / dimamba / ar / vanilla_esm_pretrain / membrane_esm_finetune
parameterization: subs  # subs / d3pm / sedd
time_conditioning: False
T: 0  # 0 (continuous time) / 1000 
subs_masking: False

seed: 42

data:
  train:
    vanilla_esm_train_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/train.csv
    membrane_esm_train_path: /workspace/sg666/MDpLM/data/membrane/train.csv
    wrap: null
  test:
    vanilla_esm_test_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/test.csv
    membrane_esm_test_path: /workspace/sg666/MDpLM/data/membrane/test.csv
    wrap: null
  valid:
    vanilla_esm_valid_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/val.csv
    membrane_esm_valid_path: /workspace/sg666/MDpLM/data/membrane/val.csv
    wrap: null
  wrapping: True

loader:
  global_batch_size: 8
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
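  # With `div_up` presumably being ceil-division (as the name suggests), the per-device batch size
  # follows from the global one, e.g. (assuming 4 GPUs on a single node): div_up(8, 4 * 1) = 2;
  # on a single GPU it stays at 8.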
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ddpm_cache  # analytic, ddpm, ddpm_cache
  steps: 128
  noise_removal: True
  # TODO(yair): @subham, why aren't these params under `eval`?
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
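  # e.g. (assuming a single-GPU run with eval_batch_size 8): 1 * 8 * 2 = 16 sampled sequences.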
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  mlm_model_path: /workspace/sg666/MDpLM/benchmarks/MLM/model_ckpts_650M/best_model_epoch
  esm_model_path: facebook/esm2_t30_150M_UR50D
  focus_mask: False

eval:
  checkpoint_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/eos-wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/checkpoints/best.ckpt # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  generation_model: /workspace/sg666/MDpLM/checkpoints/membrane_automodel/epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/

optim:
  weight_decay: 0.075
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

model:
  hidden_size: 1280
  cond_dim: 256
  n_heads: 20
  n_blocks: 4
  dropout: 0.5
  length: null #512
  scale_by_sigma: True

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
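  # e.g. (assuming 4 GPUs, 1 node, per-device batch_size 2): div_up(8, 4 * 2 * 1) = 1, i.e. no accumulation;
  # accumulation only exceeds 1 if the per-device batch_size is overridden below the derived value above.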
  gradient_clip_val: 1.0
  precision: bf16
  num_sanity_val_steps: 2
  max_epochs: 60
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0   # train on the full dataset; lower this for a quick run
  limit_val_batches: 1.0     # validate on the full dataset; lower this for a quick run
  val_check_interval: 955
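  # Integer value: run validation every 955 training batches (a float would instead mean a fraction of an epoch).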

wandb:
  project: MDpLM_finetune_membrane_200k-seqs
  notes: null
  group: programmablebio
  job_type: null
  name: dit_test #dit_wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16
  id: ${.name}_${seed}

hydra:
  run:
    dir: /workspace/sg666/MDpLM/outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use a custom `save_dir` if, e.g., saving to an S3 bucket; otherwise leave this parameter as-is
  save_dir: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: false
  resume_ckpt_path: ${.save_dir}/epochs30_lr3e-4_bsz8_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params_no-compile/checkpoints/last.ckpt #/checkpoints/last.ckpt
  pretrained_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/vanilla_esm_pretrained_automodel/epochs10_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
  finetuned_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/