Update README.md
Browse files
README.md
CHANGED
@@ -65,4 +65,43 @@ Cite TRL as:
     publisher = {GitHub},
     howpublished = {\url{https://github.com/huggingface/trl}}
 }
-```
+```
+
+# Train the model
+training_args = DPOConfig(
+    output_dir="llava-lora-12-06-rpo-0.1",
+    bf16=True,
+    gradient_checkpointing=True,
+    per_device_train_batch_size=8,
+    per_device_eval_batch_size=4,
+    gradient_accumulation_steps=32,
+    evaluation_strategy="steps",
+    eval_steps=1,
+    learning_rate=1e-5,
+    beta=0.1,
+    warmup_ratio=0.1,
+    lr_scheduler_type="cosine",
+    num_train_epochs=2,
+    rpo_alpha=0.1,
+    dataset_num_proc=32,  # tokenization will use 32 processes
+    dataloader_num_workers=32,  # data loading will use 32 workers
+    logging_steps=1,
+)
+
+# Define LoRA configuration with specified rank
+lora_config = LoraConfig(
+    r=64,  # Set rank to 64
+    lora_alpha=128,  # Set scaling factor to 128
+    target_modules="all-linear",  # Target all linear layers
+    lora_dropout=0.1,
+)
+
+trainer = DPOTrainer(
+    model,
+    ref_model=None,  # not needed when using peft
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    tokenizer=processor,
+    peft_config=lora_config,
+)