# Sequence-length cap handed to SFTTrainer below as ``max_seq_length=``.
max_seq_length: int = 500

def fmt(examples):
    """Debug hook: print ``len(examples)`` and hand ``examples`` straight back.

    The very same object is returned (no copy, no transformation). The
    trainer call in this file references it only through a commented-out
    ``formatting_func = fmt`` argument.

    NOTE(review): if ``examples`` arrives as a dict of columns, ``len()``
    counts the columns rather than the rows -- confirm that is the number
    you want printed.
    """
    size = len(examples)
    print(size)
    return examples
    
# LoRA adapter hyperparameters, fed into the LoraConfig that follows.
lora_r = 32            # adapter rank, passed as LoraConfig(r=...)
lora_alpha = 16        # LoRA scaling parameter (PEFT scales updates by lora_alpha / r by default)
lora_dropout = 0.05    # dropout probability applied inside the LoRA layers

# Modules that receive LoRA adapters -- presumably the attention (q/k/v/o)
# and MLP (gate/up/down) projections of a Llama-style model.
target_modules = [
    "k_proj",
    "q_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "down_proj",
    "up_proj",
]

# 'peft_config' collects the LoRA settings above into a PEFT LoraConfig.

# NOTE(review): the SFTTrainer call below has ``peft_config=peft_config``
# commented out, so this config is not handed to the trainer as written.
# Unless LoRA is applied to ``model`` somewhere outside this view, the run
# presumably trains the full model -- confirm that is intended.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,       # adapters for a causal-LM task
    target_modules=target_modules,      # which modules get adapters
    r=lora_r,                           # adapter rank
    lora_alpha=lora_alpha,              # scaling parameter
    lora_dropout=lora_dropout,          # dropout inside LoRA layers
)

# Supervised fine-tuning run: trains on qa_dataset['train'], evaluates on
# qa_dataset['test'] every 50 steps, and writes outputs to
# "llama_3b_step2_batch_v4".
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,

    # Data: the "text" column is the training text; the custom collator
    # builds the batches.
    train_dataset=qa_dataset['train'],
    eval_dataset=qa_dataset['test'],
    dataset_text_field="text",
    dataset_num_proc=4,
    max_seq_length=max_seq_length,
    data_collator=collator,

    # Disabled hooks, kept for quick toggling:
    #   formatting_func=fmt,
    #   peft_config=peft_config,   # LoRA adapters (see LoraConfig above)

    args=TrainingArguments(
        output_dir="llama_3b_step2_batch_v4",

        # Schedule: two full epochs, linear LR schedule, no warmup configured.
        num_train_epochs=2,
        learning_rate=3e-5,
        lr_scheduler_type="linear",
        # Alternatives (step-limited runs / warmup / separate embedding LR,
        # e.g. a 2-10x smaller rate for the embedding matrices):
        #   max_steps=70, warmup_steps=10, warmup_ratio=0.1,
        #   embedding_learning_rate=1e-6,

        # Optimizer.
        optim="adamw_torch",
        weight_decay=0.0,

        # Batching: 6 sequences x 4 accumulation steps = 24 sequences per
        # device per optimizer step; evaluation uses larger batches of 40.
        per_device_train_batch_size=6,
        gradient_accumulation_steps=4,
        per_device_eval_batch_size=40,

        # Memory / precision.
        gradient_checkpointing=True,
        bf16=True,
        #   fp16=not is_bfloat16_supported(),  # fallback when bf16 is unavailable

        # Evaluate every 50 steps, log every step, checkpoint every 1000 steps.
        do_eval=True,
        eval_strategy='steps',
        eval_steps=50,
        logging_steps=1,
        save_steps=1000,
        #   save_strategy='steps',
        #   seed=3407,
    ),
)