# Tokenizer
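# Llama 3 uses a tiktoken-based BPE tokenizer; path should point at the tokenizer.model
# file that ships with the base model download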
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ../tokenizer.model

# Dataset and Sampler
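# custom_datasets.orpo_dpo_mix_40k_dataset is a user-defined builder, not one shipped with
# torchtune; the custom_datasets module must be importable (e.g. on PYTHONPATH) at launch time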
dataset:
  _component_: custom_datasets.orpo_dpo_mix_40k_dataset
  max_seq_len: 8192
#dataset:
#  _component_: torchtune.datasets.stack_exchanged_paired_dataset
seed: 42
shuffle: True
batch_size: 2

# Model Arguments
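# Same shape hyperparameters as Llama 3 8B, but with num_layers reduced from 32 to 20,
# which comes out to roughly 5B parameters (hence the llama3-5b output_dir below)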
model:
  _component_: torchtune.models.llama3.llama3
  vocab_size: 128256
  num_layers: 20
  num_heads: 32
  num_kv_heads: 8
  embed_dim: 4096
  max_seq_len: 8192
  intermediate_dim: 14336
  attn_dropout: 0.0
  norm_eps: 1e-5
  rope_base: 500000.0

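# Checkpointer: loads the HF-format safetensors shards listed in checkpoint_files and
# saves fine-tuned checkpoints back in the same HF format under output_dir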
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ../
  checkpoint_files: [
    model-00001-of-00003.safetensors,
    model-00002-of-00003.safetensors,
    model-00003-of-00003.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ./llama3-5b/
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
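# With batch_size: 2 and gradient_accumulation_steps: 2 below, the effective batch size
# is 4 sequences per optimizer step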
epochs: 5
optimizer:
  _component_: torch.optim.AdamW  # or bitsandbytes.optim.PagedAdamW8bit to cut optimizer-state memory
  lr: 3e-6
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 1500
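  # warmup is measured in optimizer steps (the scheduler typically steps once per
  # optimizer update, i.e. after gradient accumulation), not in raw batches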
#loss:
#  _component_: torchtune.modules.loss.DPOLoss
#  beta: 0.1
#  label_smoothing: 0
#  loss_type: sigmoid
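# Active loss: plain next-token cross-entropy; swap the DPOLoss block above back in
# for DPO-style preference training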
loss:
  _component_: torch.nn.CrossEntropyLoss

max_steps_per_epoch: null
gradient_accumulation_steps: 2
optimizer_in_bwd: False  # must be False when gradient_accumulation_steps > 1
compile: False

# Training environment
device: cuda

# Memory management
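# Activation checkpointing recomputes activations in the backward pass instead of storing
# them, trading extra compute for a large reduction in activation memory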
enable_activation_checkpointing: True

# Reduced precision
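# bf16 roughly halves parameter/gradient/activation memory vs fp32, but requires
# bf16-capable hardware (e.g. NVIDIA Ampere or newer GPUs)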
dtype: bf16  # fall back to fp32 if bf16 is unsupported

# Logging
# enable logging to the built-in WandBLogger
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  # the W&B project to log to
  project: llama3-5b
output_dir: ./logs/
log_every_n_steps: 1
log_peak_memory_stats: False
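
# Example launch command (the recipe name here is an assumption; substitute the
# torchtune recipe this config was written for):
#   tune run full_finetune_single_device --config <path-to-this-config>.yaml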