# MeMDLM / config.yaml
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup
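
# Hydra composes this file with the config groups listed above. Because `_self_`
# comes first in the defaults list, the groups (e.g. `/model: small`) are merged
# after this file's own contents and override any keys both define.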
mode: sample_eval # train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: membrane_esm_finetune # dit / dimamba / ar / vanilla_esm_pretrain / membrane_esm_finetune
parameterization: subs # subs / d3pm / sedd
time_conditioning: False
T: 0 # 0 (continuous time) / 1000
subs_masking: False
seed: 42

data:
  train:
    vanilla_esm_train_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/train.csv
    membrane_esm_train_path: /workspace/sg666/MDpLM/data/membrane/train.csv
    wrap: null
  test:
    vanilla_esm_test_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/test.csv
    membrane_esm_test_path: /workspace/sg666/MDpLM/data/membrane/test.csv
    wrap: null
  valid:
    vanilla_esm_valid_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/val.csv
    membrane_esm_valid_path: /workspace/sg666/MDpLM/data/membrane/val.csv
    wrap: null
  wrapping: True

loader:
  global_batch_size: 8
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are per device, not global.
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
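  # `div_up` (ceiling division), `eval`, and `device_count` are not built-in
  # OmegaConf resolvers; they are assumed to be registered in the training
  # entrypoint before this config is resolved, roughly:
  #   OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)
  #   OmegaConf.register_new_resolver("eval", eval)
  #   OmegaConf.register_new_resolver("device_count", torch.cuda.device_count)
  # Worked example (hypothetical run): devices=2, num_nodes=1, global_batch_size=8
  #   => batch_size = div_up(8, 2 * 1) = 4 sequences per device per step.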
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
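  # `os.sched_getaffinity(0)` returns the set of CPUs this process may run on,
  # so this defaults to one worker per available core (a Linux-only API).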
  pin_memory: True

sampling:
  predictor: ddpm_cache  # analytic / ddpm / ddpm_cache
  steps: 128
  noise_removal: True
  # TODO(yair): @subham, why aren't these params under `eval`?
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * `num_sample_batches`
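  # e.g. with 1 GPU and eval_batch_size=8: 1 * 8 * 2 = 16 sampled sequences.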
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  mlm_model_path: /workspace/sg666/MDpLM/benchmarks/MLM/model_ckpts_650M/best_model_epoch
  esm_model_path: facebook/esm2_t30_150M_UR50D
  focus_mask: False

eval:
  checkpoint_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/eos-wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/checkpoints/best.ckpt  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large / meta-llama/Llama-2-7b-hf
  generate_samples: True
  generation_model: /workspace/sg666/MDpLM/checkpoints/membrane_automodel/epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/

optim:
  weight_decay: 0.075
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

# NB: `Model` (capitalized) is a separate key from the `model` config group
# pulled in via the defaults list above.
Model:
  hidden_size: 1280
  cond_dim: 256
  n_heads: 20
  n_blocks: 4
  dropout: 0.5
  length: null  # 512
  scale_by_sigma: True

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
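  # Worked example (same hypothetical run as under `loader`): global_batch_size=8,
  # devices=2, batch_size=4, num_nodes=1
  #   => accumulate_grad_batches = div_up(8, 2 * 4 * 1) = 1 (no accumulation).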
  gradient_clip_val: 1.0
  precision: bf16
  num_sanity_val_steps: 2
  max_epochs: 60
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0  # fraction of the training set used per epoch; lower for a quick run
  limit_val_batches: 1.0  # fraction of the validation set used per epoch; lower for a quick run
  val_check_interval: 955

wandb:
  project: MDpLM_finetune_membrane_200k-seqs
  notes: null
  group: programmablebio
  job_type: null
  name: dit_test  # dit_wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16
  id: ${.name}_${seed}

hydra:
  run:
    dir: /workspace/sg666/MDpLM/outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use a custom `save_dir` if, e.g., saving to an S3 bucket; otherwise leave
  # this parameter as-is.
  save_dir: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
  # Note: the `checkpoints` path should match `checkpoint_every_n_steps.dirpath`.
  resume_from_ckpt: false
  resume_ckpt_path: ${.save_dir}/epochs30_lr3e-4_bsz8_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params_no-compile/checkpoints/last.ckpt  # /checkpoints/last.ckpt
  pretrained_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/vanilla_esm_pretrained_automodel/epochs10_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
  finetuned_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
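
# Minimal usage sketch (hypothetical entrypoint; the actual script name in this
# repo may differ). After registering the custom resolvers noted under `loader`:
#
#   import hydra
#   from omegaconf import OmegaConf
#
#   @hydra.main(config_path=".", config_name="config", version_base=None)
#   def main(cfg):
#       print(OmegaConf.to_yaml(cfg))  # inspect the composed config
#
#   if __name__ == "__main__":
#       main()
#
# CLI overrides follow Hydra syntax, e.g.:
#   python main.py mode=train backbone=vanilla_esm_pretrain trainer.devices=4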