---
# DPO fine-tuning configuration (data prep, base model, trainer args, quantization, LoRA).
# NOTE(review): indentation was lost in the extracted copy; the nesting below is
# reconstructed from key groupings — confirm section levels (in particular whether
# bnb/lora sit under `dpo` or at top level) against the config loader.
dpo:
  data:
    splits: "random"  # one of {random, novel_premises}
    train_size: 0.8
    include_next_state: false
    # paths are relative to the project root
    raw_data: "data/time_filtered_v3.json"
    formatted_dataset_dir: "data/straight_shot_proof_sample/filtered_negative_tactics_dataset.json"
    expand_records: false
    # processed_data: "data/straight_shot_proof_sample/dpo_expanded_dataset"
    # processed_data: "data/straight_shot_proof_sample/dpo_single_entry_dataset"
    processed_data: "data/straight_shot_proof_sample/dpo_flattened_dataset"
    # prompt formatting {llemma, deepseek}
    # model_prompt_template: "deepseek"
    model_prompt_template: "llemma"
    use_sts_format: false
  model:
    # TODO: figure this out
    base_model_id: "deepseek-ai/DeepSeek-Prover-V1"
    max_seq_length: 1024
    packing: true  # pack examples together for better efficiency
  training_args:
    # output_dir: "dspv1_dpo_dspfmt_medium"  # directory to save and repository id (relative to project root)
    output_dir: "dspv1_dpo_llemmafmt_medium"  # directory to save and repository id (relative to project root)
    num_train_epochs: 3  # number of training epochs
    per_device_train_batch_size: 3  # batch size per device during training
    gradient_accumulation_steps: 2  # number of steps before performing a backward/update pass
    gradient_checkpointing: true  # use gradient checkpointing to save memory
    optim: "adamw_torch_fused"  # use fused adamw optimizer
    logging_steps: 10  # log every 10 steps
    save_strategy: "epoch"  # save checkpoint every epoch
    learning_rate: 0.0002  # learning rate, based on QLoRA paper
    bf16: true  # use bfloat16 precision
    tf32: true  # use tf32 precision
    max_grad_norm: 0.3  # max gradient norm based on QLoRA paper
    warmup_ratio: 0.03  # warmup ratio based on QLoRA paper
    lr_scheduler_type: "constant"  # use constant learning rate scheduler
    push_to_hub: true  # push model to hub
    report_to: "tensorboard"  # report metrics to tensorboard
    beta: 0.01  # TODO: tune this (beta for the DPO loss function)
  bnb:
    _target_: transformers.BitsAndBytesConfig
    load_in_4bit: true
    bnb_4bit_use_double_quant: true
    bnb_4bit_quant_type: "nf4"
    bnb_4bit_compute_dtype: "bfloat16"
  lora:
    _target_: peft.LoraConfig
    r: 16
    lora_alpha: 32
    lora_dropout: 0.05
    bias: "none"
    target_modules: "all-linear"
    task_type: "CAUSAL_LM"