use mixins for orpo and kto configs so they work with axolotl customizations (#1674)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|
91 |
|
92 |
|
93 |
@dataclass
|
94 |
-
class
|
95 |
"""
|
96 |
-
|
97 |
"""
|
98 |
|
|
|
99 |
model_type: Optional[str] = field(
|
100 |
default=None, metadata={"help": "HF model configuration model_type."}
|
101 |
)
|
@@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
227 |
)
|
228 |
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
class AxolotlTrainer(Trainer):
|
231 |
"""
|
232 |
Extend the base Trainer for axolotl helpers
|
@@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1583 |
|
1584 |
training_args_cls = AxolotlTrainingArguments
|
1585 |
if self.cfg.rl == "orpo":
|
1586 |
-
training_args_cls =
|
1587 |
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
1588 |
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
1589 |
if self.cfg.max_prompt_len:
|
1590 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1591 |
|
1592 |
if self.cfg.rl == "kto":
|
1593 |
-
training_args_cls =
|
1594 |
|
1595 |
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
1596 |
training_args_kwargs["desirable_weight"] = (
|
@@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1605 |
if self.cfg.max_prompt_len:
|
1606 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1607 |
|
1608 |
-
training_args = training_args_cls(
|
|
|
1609 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
1610 |
max_steps=self.cfg.max_steps or total_num_steps,
|
1611 |
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
1612 |
learning_rate=self.cfg.learning_rate,
|
1613 |
-
output_dir=self.cfg.output_dir,
|
1614 |
warmup_steps=self.cfg.warmup_steps,
|
1615 |
logging_first_step=True,
|
1616 |
logging_steps=1,
|
|
|
91 |
|
92 |
|
93 |
@dataclass
|
94 |
+
class AxolotlTrainingMixins:
|
95 |
"""
|
96 |
+
Mixin class for the Axolotl training args.
|
97 |
"""
|
98 |
|
99 |
+
# pylint: disable=duplicate-code
|
100 |
model_type: Optional[str] = field(
|
101 |
default=None, metadata={"help": "HF model configuration model_type."}
|
102 |
)
|
|
|
228 |
)
|
229 |
|
230 |
|
231 |
+
@dataclass
|
232 |
+
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
233 |
+
"""
|
234 |
+
Training arguments for Causal trainer
|
235 |
+
|
236 |
+
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
|
237 |
+
so it can't be used as a mixin.
|
238 |
+
"""
|
239 |
+
|
240 |
+
|
241 |
+
@dataclass
|
242 |
+
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
243 |
+
"""
|
244 |
+
ORPO config for ORPO training
|
245 |
+
"""
|
246 |
+
|
247 |
+
|
248 |
+
@dataclass
|
249 |
+
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
250 |
+
"""
|
251 |
+
KTO config for KTO training
|
252 |
+
"""
|
253 |
+
|
254 |
+
|
255 |
class AxolotlTrainer(Trainer):
|
256 |
"""
|
257 |
Extend the base Trainer for axolotl helpers
|
|
|
1608 |
|
1609 |
training_args_cls = AxolotlTrainingArguments
|
1610 |
if self.cfg.rl == "orpo":
|
1611 |
+
training_args_cls = AxolotlORPOConfig
|
1612 |
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
1613 |
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
1614 |
if self.cfg.max_prompt_len:
|
1615 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1616 |
|
1617 |
if self.cfg.rl == "kto":
|
1618 |
+
training_args_cls = AxolotlKTOConfig
|
1619 |
|
1620 |
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
1621 |
training_args_kwargs["desirable_weight"] = (
|
|
|
1630 |
if self.cfg.max_prompt_len:
|
1631 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1632 |
|
1633 |
+
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
1634 |
+
output_dir=self.cfg.output_dir,
|
1635 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
1636 |
max_steps=self.cfg.max_steps or total_num_steps,
|
1637 |
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
1638 |
learning_rate=self.cfg.learning_rate,
|
|
|
1639 |
warmup_steps=self.cfg.warmup_steps,
|
1640 |
logging_first_step=True,
|
1641 |
logging_steps=1,
|