defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /model: small
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

mode: sample_eval  # train / ppl_eval / sample_eval
diffusion: absorbing_state
backbone: membrane_esm_finetune  # dit / dimamba / ar / vanilla_esm_pretrain / membrane_esm_finetune
parameterization: subs  # subs / d3pm / sedd
time_conditioning: False
T: 0  # 0 (continuous time) / 1000
subs_masking: False
seed: 42

data:
  train:
    vanilla_esm_train_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/train.csv
    membrane_esm_train_path: /workspace/sg666/MDpLM/data/membrane/train.csv
    wrap: null
  test:
    vanilla_esm_test_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/test.csv
    membrane_esm_test_path: /workspace/sg666/MDpLM/data/membrane/test.csv
    wrap: null
  valid:
    vanilla_esm_valid_path: /workspace/sg666/MDpLM/data/uniref50/200k_seqs/val.csv
    membrane_esm_valid_path: /workspace/sg666/MDpLM/data/membrane/val.csv
    wrap: null
  wrapping: True

loader:
  global_batch_size: 8
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per device**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True
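# Worked example of the batch-size interpolations above (a sketch only; it
# assumes a single node with 4 visible GPUs, so ${device_count:} resolves to 4):
#   batch_size              = div_up(8, 4 * 1)     = 2 sequences per device
#   eval_batch_size         = div_up(8, 4 * 1)     = 2 sequences per device
#   trainer.accumulate_grad_batches (defined below)
#                           = div_up(8, 4 * 2 * 1) = 1, i.e. no gradient accumulation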
sampling:
  predictor: ddpm_cache  # analytic, ddpm, ddpm_cache
  steps: 128
  noise_removal: True
  # TODO(yair): @subham, why aren't these params under `eval`?
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  mlm_model_path: /workspace/sg666/MDpLM/benchmarks/MLM/model_ckpts_650M/best_model_epoch
  esm_model_path: facebook/esm2_t30_150M_UR50D
  focus_mask: False

eval:
  checkpoint_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/eos-wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/checkpoints/best.ckpt  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  generation_model: /workspace/sg666/MDpLM/checkpoints/membrane_automodel/epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/

optim:
  weight_decay: 0.075
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

model:
  hidden_size: 1280
  cond_dim: 256
  n_heads: 20
  n_blocks: 4
  dropout: 0.5
  length: null  # 512
  scale_by_sigma: True

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: bf16
  num_sanity_val_steps: 2
  max_epochs: 60
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0  # train on the full dataset; lower this fraction for a quick run
  limit_val_batches: 1.0  # validate on the full dataset; lower this fraction for a quick run
  val_check_interval: 955

wandb:
  project: MDpLM_finetune_membrane_200k-seqs
  notes: null
  group: programmablebio
  job_type: null
  name: dit_test  # dit_wrapping_epochs60_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16
  id: ${.name}_${seed}

hydra:
  run:
    dir: /workspace/sg666/MDpLM/outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use a custom `save_dir` if, e.g., saving to an S3 bucket; otherwise leave this parameter as is.
  save_dir: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
  # Note: the `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`.
  resume_from_ckpt: false
  resume_ckpt_path: ${.save_dir}/epochs30_lr3e-4_bsz8_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params_no-compile/checkpoints/last.ckpt  # /checkpoints/last.ckpt
  pretrained_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/vanilla_esm_pretrained_automodel/epochs10_lr3e-4_200k-seqs_bsz16_all-params_no-compile_gradclip1_beta-one0.9_beta-two0.999_bf16/
  finetuned_esm_mdlm_automodel_path: /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/
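# Usage note: `${.save_dir}` is an OmegaConf relative interpolation resolved against this
# `checkpointing` node, so `resume_ckpt_path` expands under
# /workspace/sg666/MDpLM/checkpoints/membrane_mdlm/. To resume a run, presumably set
# `resume_from_ckpt: true` and point `resume_ckpt_path` at the desired `last.ckpt`.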