"""Callbacks for Trainer class""" | |
import os | |
from optimum.bettertransformer import BetterTransformer | |
from transformers import ( | |
TrainerCallback, | |
TrainerControl, | |
TrainerState, | |
TrainingArguments, | |
) | |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy | |
class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
    """Callback that persists the PEFT adapter weights at each checkpoint.

    The trainer's own save logic handles the base model; this hook writes the
    adapter into an ``adapter_model`` subdirectory of the step's checkpoint
    folder so it can be reloaded independently.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter under ``<output_dir>/checkpoint-<step>/adapter_model``."""
        checkpoint_dir = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        adapter_dir = os.path.join(checkpoint_dir, "adapter_model")
        # The PEFT-wrapped model writes only the adapter weights here.
        kwargs["model"].save_pretrained(adapter_dir)
        return control
class SaveBetterTransformerModelCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods
    """Callback that saves a BetterTransformer-wrapped model at save steps.

    A BetterTransformer-wrapped model cannot be saved by the trainer directly,
    so this hook reverses the wrapping, saves the plain model itself, and then
    clears ``control.should_save`` so the trainer skips its own save.
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Check the save schedule and, if due, save the un-wrapped model."""
        # Mirror the trainer's step-interval save condition so we can intercept
        # the save before the trainer attempts (and fails) to do it.
        due_by_steps = (
            args.save_strategy == IntervalStrategy.STEPS
            and args.save_steps > 0
            and state.global_step % args.save_steps == 0
        )
        if due_by_steps:
            control.should_save = True

        if control.should_save:
            checkpoint_dir = os.path.join(
                args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
            )
            unwrapped_model = BetterTransformer.reverse(kwargs["model"])
            unwrapped_model.save_pretrained(checkpoint_dir)
            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
            control.should_save = False
        return control