winglian commited on
Commit
934fc85
·
unverified ·
1 Parent(s): bda48f0

drop empty token from beginning if tokenizer has no bos_token (in the case of qwen) (#1490)

Browse files
Files changed (1) hide show
  1. src/axolotl/core/trainer_builder.py +10 -0
src/axolotl/core/trainer_builder.py CHANGED
@@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
23
  from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
24
  from transformers import (
25
  EarlyStoppingCallback,
 
26
  Trainer,
27
  TrainerCallback,
28
  TrainingArguments,
@@ -802,6 +803,15 @@ class AxolotlDPOTrainer(DPOTrainer):
802
 
803
  return super().push_to_hub(*args, **kwargs)
804
 
 
 
 
 
 
 
 
 
 
805
 
806
  class TrainerBuilderBase(abc.ABC):
807
  """
 
23
  from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
24
  from transformers import (
25
  EarlyStoppingCallback,
26
+ PreTrainedModel,
27
  Trainer,
28
  TrainerCallback,
29
  TrainingArguments,
 
803
 
804
  return super().push_to_hub(*args, **kwargs)
805
 
806
def tokenize_row(
    self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
    """Tokenize one DPO example, dropping the bogus leading token when
    the tokenizer defines no BOS token.

    Upstream ``DPOTrainer.tokenize_row`` unconditionally prepends
    ``tokenizer.bos_token_id`` to each token sequence; for tokenizers
    without a BOS token (e.g. Qwen) that id is ``None``, leaving a
    ``None`` placeholder at index 0 of every field. Strip it here.

    :param feature: raw example dict passed through to the parent trainer.
    :param model: optional model forwarded to ``super().tokenize_row``.
    :return: the tokenized feature dict, with the leading placeholder
        removed from every field when one was inserted.
    """
    row = super().tokenize_row(feature, model=model)
    # Only rewrite when both conditions hold: no BOS id configured AND the
    # placeholder is actually present at the head of the prompt ids.
    if self.tokenizer.bos_token_id is None and row["prompt_input_ids"][0] is None:
        row = {field: values[1:] for field, values in row.items()}
    return row
814
+
815
 
816
  class TrainerBuilderBase(abc.ABC):
817
  """