import json
import os
import time
from datetime import timedelta
from typing import TYPE_CHECKING

from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)


class FixValueHeadModelCallback(TrainerCallback):
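    r"""
    Callback that repairs the checkpoint of a model with a value head after each save.
    """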
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
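            # post-process the checkpoint that was just written for this global step
            # (see fix_valuehead_checkpoint for how the value-head weights are rearranged)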
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )


class LogCallback(TrainerCallback):
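    r"""
    Callback that tracks training progress (step counts, elapsed and remaining time)
    and appends the latest metrics to a JSONL log file in the output directory.
    """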
    def __init__(self, runner=None):
        self.runner = runner
        self.in_training = False
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""

    def timing(self):
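        # derive elapsed/remaining time from the average duration of completed steps;
        # both are stored as "H:MM:SS" strings for display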
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            self.in_training = True
            self.start_time = time.time()
            self.max_steps = state.max_steps
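            # when overwriting the output dir, drop the stale log file from a previous run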
            if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
                logger.warning("Previous log file in this folder will be deleted.")
                os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if state.is_local_process_zero:
            self.in_training = False
            self.cur_steps = 0
            self.max_steps = 0

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if state.is_local_process_zero:
            self.cur_steps = state.global_step
            self.timing()
            if self.runner is not None and self.runner.aborted:
                control.should_epoch_stop = True
                control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_predict(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
    ):
        r"""
        Event called after a successful prediction.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if not state.is_local_process_zero:
            return

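        # flatten the most recent entry of state.log_history into a progress record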
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            reward=state.log_history[-1].get("reward", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if self.runner is not None:
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
                )
            )

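        # append one JSON line per log event; an external consumer (e.g. the runner's UI)
        # can tail this file to display live progress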
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.
        """
        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
            if self.max_steps == 0:
                self.max_steps = len(eval_dataloader)
            self.cur_steps += 1
            self.timing()
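
# Usage sketch (illustrative only; `model`, `training_args`, and `train_dataset`
# are assumed to be defined elsewhere). Both classes are ordinary `transformers`
# callbacks, so they can be attached through the Trainer's `callbacks` argument:
#
#     from transformers import Trainer
#
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         callbacks=[LogCallback(), FixValueHeadModelCallback()],
#     )
#     trainer.train()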