drop empty token from beginning if tokenizer has no bos_token (in the case of qwen) (#1490)
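Background for the change, as a minimal sketch (the checkpoint name below is only an illustrative example, not something this commit touches): Qwen-style tokenizers ship without a BOS token, so tokenizer.bos_token_id is None, and the upstream DPOTrainer.tokenize_row then leaves a None placeholder at the start of the tokenized prompt, which is what the override below strips off.

from transformers import AutoTokenizer

# "Qwen/Qwen1.5-7B" is just an example checkpoint used to show the condition
# the patch guards against; the commit itself does not reference it.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B")
print(tok.bos_token_id)  # None -- there is no BOS token to prepend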
src/axolotl/core/trainer_builder.py (changed)
@@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
     EarlyStoppingCallback,
+    PreTrainedModel,
     Trainer,
     TrainerCallback,
     TrainingArguments,
@@ -802,6 +803,15 @@ class AxolotlDPOTrainer(DPOTrainer):

         return super().push_to_hub(*args, **kwargs)

+    def tokenize_row(
+        self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
+    ) -> Dict:
+        res = super().tokenize_row(feature, model=model)
+        if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
+            for key in res.keys():
+                res[key] = res[key][1:]
+        return res
+

 class TrainerBuilderBase(abc.ABC):
     """
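To make the effect of the override concrete, here is a small standalone sketch (the helper name and the sample token ids are illustrative, not part of the patch): it reproduces the guard from the diff and strips the leading placeholder from every field of a tokenize_row-style result.

def drop_leading_placeholder(res: dict) -> dict:
    # Same condition as the patch: only act when the prompt starts with a None
    # left behind by a missing bos_token_id.
    if res["prompt_input_ids"][0] is None:
        for key in res.keys():
            res[key] = res[key][1:]
    return res

# Fake tokenize_row output for a BOS-less tokenizer (token ids are made up).
row = {
    "prompt_input_ids": [None, 101, 102],
    "prompt_attention_mask": [1, 1, 1],
}
print(drop_leading_placeholder(row))
# {'prompt_input_ids': [101, 102], 'prompt_attention_mask': [1, 1]}

Slicing every field, not just prompt_input_ids, keeps the input ids, attention mask and any label fields aligned after the placeholder is removed.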