# sdlm/trainers/trainer_ar.py
import math
import time
from typing import Dict, List, Optional

from torch.utils.data import DataLoader, Dataset
from transformers import Seq2SeqTrainer
from transformers.integrations import TensorBoardCallback
from transformers.trainer_utils import speed_metrics
from transformers.utils import logging

# Module-level placeholders; not referenced directly within this trainer.
skip_first_batches = None
IS_SAGEMAKER_MP_POST_1_10 = False

GENERATION_RESULTS = "generated"

logger = logging.get_logger(__name__)

class ARTrainer(Seq2SeqTrainer):
    """``Seq2SeqTrainer`` subclass that builds its data collator per split and can
    log generated texts to TensorBoard during evaluation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, preprocess_logits_for_metrics=None)
        self.tb_writer = self.get_tb_writer()
        # Keep the original collator factory; it is re-invoked per split below.
        self.original_data_collator = self.data_collator

    def get_tb_writer(self):
        for cb in self.callback_handler.callbacks:
            if isinstance(cb, TensorBoardCallback):
                return cb
        return None
    def log_results_to_tensorboard(self, output):
        # TODO: fix this properly; tb_writer.tb_writer can be None when running in eval-only mode.
        if self.tb_writer.tb_writer is None:
            return
        for i, (label, prediction) in enumerate(
            zip(output.label_ids, output.predictions)
        ):
            try:
                total_text = ""
                # Drop the -100 entries used to mask out ignored label positions.
                decoded_label = self.tokenizer.decode(label[label != -100])
                decoded_prediction = self.tokenizer.decode(
                    prediction[prediction != -100]
                )
                total_text += f"*** label ***: {decoded_label} \n"
                total_text += f"*** prediction ***: {decoded_prediction}"
                self.tb_writer.tb_writer.add_text(
                    f"sample_{i}", total_text, self.state.global_step
                )
            except OverflowError:
                print("[ERROR] tokenization", prediction)
    def get_train_dataloader(self) -> DataLoader:
        # `original_data_collator` is a factory keyed on the split name.
        self.data_collator = self.original_data_collator("train")
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        self.data_collator = self.original_data_collator("eval")
        return super().get_eval_dataloader(eval_dataset)
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs,
    ) -> Dict[str, float]:
        """
        Copied from
        - https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
        - https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_seq2seq.py
        with added tensorboard text logging.
        """
        gen_kwargs = gen_kwargs.copy()

        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
        # training args
        if (
            gen_kwargs.get("max_length") is None
            and gen_kwargs.get("max_new_tokens") is None
            and self.args.generation_max_length is not None
        ):
            gen_kwargs["max_length"] = self.args.generation_max_length
        if (
            gen_kwargs.get("num_beams") is None
            and self.args.generation_num_beams is not None
        ):
            gen_kwargs["num_beams"] = self.args.generation_num_beams
        # We don't want to drop samples in general
        self.gather_function = self.accelerator.gather
        self._gen_kwargs = gen_kwargs

        # return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(eval_dataset, dict):
            metrics = {}
            for eval_dataset_name, _eval_dataset in eval_dataset.items():
                dataset_metrics = self.evaluate(
                    eval_dataset=_eval_dataset,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
                )
                metrics.update(dataset_metrics)
            return metrics

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        # NOTE: no tpu
        # if self.is_fsdp_xla_v2_enabled:
        #     eval_dataloader = tpu_spmd_dataloader(eval_dataloader)

        start_time = time.time()

        eval_loop = (
            self.prediction_loop
            if self.args.use_legacy_prediction_loop
            else self.evaluation_loop
        )
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        # NOTE: no tpu
        # if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
        #     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        #     xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, output.metrics
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        # NOTE: text logging
        if self.args.log_generated_texts:
            self.log_results_to_tensorboard(output)

        return output.metrics
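

# Usage sketch (illustrative only; `model`, `training_args`, `tokenizer`,
# `train_dataset`, `eval_dataset`, and `compute_metrics` are assumed to be defined
# elsewhere, and `training_args` is assumed to carry the custom `log_generated_texts`
# flag used above). The point is the collator contract: ARTrainer calls
# `self.original_data_collator("train")` / `("eval")`, so `data_collator` should be a
# factory mapping a split name to a collator, not a plain collator instance:
#
#     from transformers import DataCollatorForSeq2Seq
#
#     def collator_factory(split: str):
#         # Could vary padding or label handling per split; identical here for simplicity.
#         return DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest")
#
#     trainer = ARTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         tokenizer=tokenizer,
#         data_collator=collator_factory,
#         compute_metrics=compute_metrics,
#     )
#     trainer.train()
#     trainer.evaluate()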