winglian committed on
Commit
f7332ac
·
unverified ·
1 Parent(s): 16d46b7

use mixins for orpo and kto configs so they work with axolotl customizations (#1674)

Browse files
Files changed (1) hide show
  1. src/axolotl/core/trainer_builder.py +31 -6
src/axolotl/core/trainer_builder.py CHANGED
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
91
 
92
 
93
  @dataclass
94
- class AxolotlTrainingArguments(TrainingArguments):
95
  """
96
- Extend the base TrainingArguments for axolotl helpers
97
  """
98
 
 
99
  model_type: Optional[str] = field(
100
  default=None, metadata={"help": "HF model configuration model_type."}
101
  )
@@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
227
  )
228
 
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  class AxolotlTrainer(Trainer):
231
  """
232
  Extend the base Trainer for axolotl helpers
@@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1583
 
1584
  training_args_cls = AxolotlTrainingArguments
1585
  if self.cfg.rl == "orpo":
1586
- training_args_cls = ORPOConfig
1587
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
1588
  training_args_kwargs["max_length"] = self.cfg.sequence_len
1589
  if self.cfg.max_prompt_len:
1590
  training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1591
 
1592
  if self.cfg.rl == "kto":
1593
- training_args_cls = KTOConfig
1594
 
1595
  training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
1596
  training_args_kwargs["desirable_weight"] = (
@@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1605
  if self.cfg.max_prompt_len:
1606
  training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1607
 
1608
- training_args = training_args_cls(
 
1609
  per_device_train_batch_size=self.cfg.micro_batch_size,
1610
  max_steps=self.cfg.max_steps or total_num_steps,
1611
  gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
1612
  learning_rate=self.cfg.learning_rate,
1613
- output_dir=self.cfg.output_dir,
1614
  warmup_steps=self.cfg.warmup_steps,
1615
  logging_first_step=True,
1616
  logging_steps=1,
 
91
 
92
 
93
  @dataclass
94
+ class AxolotlTrainingMixins:
95
  """
96
+ Mixin class for the Axolotl training args.
97
  """
98
 
99
+ # pylint: disable=duplicate-code
100
  model_type: Optional[str] = field(
101
  default=None, metadata={"help": "HF model configuration model_type."}
102
  )
 
228
  )
229
 
230
 
231
+ @dataclass
232
+ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
233
+ """
234
+ Training arguments for Causal trainer
235
+
236
+ This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
237
+ so it can't be used as a mixin.
238
+ """
239
+
240
+
241
+ @dataclass
242
+ class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
243
+ """
244
+ ORPO config for ORPO training
245
+ """
246
+
247
+
248
+ @dataclass
249
+ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
250
+ """
251
+ KTO config for KTO training
252
+ """
253
+
254
+
255
  class AxolotlTrainer(Trainer):
256
  """
257
  Extend the base Trainer for axolotl helpers
 
1608
 
1609
  training_args_cls = AxolotlTrainingArguments
1610
  if self.cfg.rl == "orpo":
1611
+ training_args_cls = AxolotlORPOConfig
1612
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
1613
  training_args_kwargs["max_length"] = self.cfg.sequence_len
1614
  if self.cfg.max_prompt_len:
1615
  training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1616
 
1617
  if self.cfg.rl == "kto":
1618
+ training_args_cls = AxolotlKTOConfig
1619
 
1620
  training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
1621
  training_args_kwargs["desirable_weight"] = (
 
1630
  if self.cfg.max_prompt_len:
1631
  training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1632
 
1633
+ training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
1634
+ output_dir=self.cfg.output_dir,
1635
  per_device_train_batch_size=self.cfg.micro_batch_size,
1636
  max_steps=self.cfg.max_steps or total_num_steps,
1637
  gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
1638
  learning_rate=self.cfg.learning_rate,
 
1639
  warmup_steps=self.cfg.warmup_steps,
1640
  logging_first_step=True,
1641
  logging_steps=1,