# dspv1_dpo_llemmafmt_medium / dpo_training.yaml
dpo:
  data:
    splits: "random" # {random, novel_premises}
    train_size: 0.8
    include_next_state: false
    # paths are relative to the project root
    raw_data: "data/time_filtered_v3.json"
    formatted_dataset_dir: "data/straight_shot_proof_sample/filtered_negative_tactics_dataset.json"
    expand_records: false
    # processed_data: "data/straight_shot_proof_sample/dpo_expanded_dataset"
    # processed_data: "data/straight_shot_proof_sample/dpo_single_entry_dataset"
    processed_data: "data/straight_shot_proof_sample/dpo_flattened_dataset"
    # prompt formatting {llemma, deepseek}
    # model_prompt_template: "deepseek"
    model_prompt_template: "llemma"
    use_sts_format: false
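    # A minimal sketch (assumed, not part of this config) of loading the processed
    # dataset above; TRL's DPO pipeline expects "prompt"/"chosen"/"rejected" columns:
    #   from datasets import load_from_disk
    #   dataset = load_from_disk("data/straight_shot_proof_sample/dpo_flattened_dataset")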
  model:
    # TODO: figure this out
    base_model_id: "deepseek-ai/DeepSeek-Prover-V1"
    max_seq_length: 1024
    packing: true # pack examples together for better efficiency
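    # Assumed load-time usage of the setting above (`cfg` is a hypothetical
    # handle to this config):
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained(cfg.dpo.model.base_model_id)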
  training_args:
    # output_dir: "dspv1_dpo_dspfmt_medium" # directory to save to and repository id (relative to project root)
    output_dir: "dspv1_dpo_llemmafmt_medium" # directory to save to and repository id (relative to project root)
    num_train_epochs: 3 # number of training epochs
    per_device_train_batch_size: 3 # batch size per device during training
    gradient_accumulation_steps: 2 # number of steps before performing a backward/update pass
    gradient_checkpointing: true # use gradient checkpointing to save memory
    optim: "adamw_torch_fused" # use the fused AdamW optimizer
    logging_steps: 10 # log every 10 steps
    save_strategy: "epoch" # save a checkpoint every epoch
    learning_rate: 0.0002 # learning rate, based on the QLoRA paper
    bf16: true # use bfloat16 precision
    tf32: true # use tf32 precision
    max_grad_norm: 0.3 # max gradient norm, based on the QLoRA paper
    warmup_ratio: 0.03 # warmup ratio, based on the QLoRA paper
    lr_scheduler_type: "constant" # use a constant learning-rate scheduler
    push_to_hub: true # push the model to the Hub
    report_to: "tensorboard" # report metrics to TensorBoard
    beta: 0.01 # TODO: tune this (beta coefficient for the DPO loss)
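    # Note (added for clarity): with gradient accumulation, the effective batch
    # size is per_device_train_batch_size * gradient_accumulation_steps * n_gpus,
    # i.e. 3 * 2 = 6 examples per optimizer step on a single GPU. A sketch of how
    # this block could map onto transformers, assuming the training script pops
    # the DPO-only `beta` key first:
    #   from transformers import TrainingArguments
    #   args_dict = dict(cfg.dpo.training_args)
    #   beta = args_dict.pop("beta")
    #   training_args = TrainingArguments(**args_dict)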
bnb:
  _target_: transformers.BitsAndBytesConfig
  load_in_4bit: true
  bnb_4bit_use_double_quant: true
  bnb_4bit_quant_type: nf4
  bnb_4bit_compute_dtype: bfloat16
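# Sketch (assumed from the Hydra-style `_target_` key) of how this block would
# be instantiated and applied when loading the model:
#   import hydra.utils
#   from transformers import AutoModelForCausalLM
#   bnb_config = hydra.utils.instantiate(cfg.bnb)  # -> transformers.BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       cfg.dpo.model.base_model_id,
#       quantization_config=bnb_config,
#       device_map="auto",
#   )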
lora:
  _target_: peft.LoraConfig
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  bias: "none"
  target_modules: "all-linear"
  task_type: "CAUSAL_LM"
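# Sketch of the assumed DPO training wiring (`model`, `tokenizer`, `dataset`, and
# `training_args` come from the sketches above; TRL's older DPOTrainer API, which
# takes `beta` directly, is shown):
#   import hydra.utils
#   from trl import DPOTrainer
#   peft_config = hydra.utils.instantiate(cfg.lora)  # -> peft.LoraConfig
#   trainer = DPOTrainer(
#       model=model,
#       ref_model=None,  # with a PEFT adapter, the frozen base model serves as reference
#       args=training_args,
#       beta=beta,  # DPO loss temperature from training_args above
#       train_dataset=dataset,
#       tokenizer=tokenizer,
#       peft_config=peft_config,
#   )
#   trainer.train()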