import math
import time
from typing import Dict, List, Optional

from torch.utils.data import DataLoader, Dataset
from transformers import Seq2SeqTrainer
from transformers.integrations import TensorBoardCallback
from transformers.trainer_utils import speed_metrics
from transformers.utils import logging

skip_first_batches = None
IS_SAGEMAKER_MP_POST_1_10 = False
GENERATION_RESULTS = "generated"

logger = logging.get_logger(__name__)
class ARTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer variant that resolves a split-specific data collator and can
    log decoded label/prediction pairs as text to TensorBoard during evaluation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, preprocess_logits_for_metrics=None)
        self.tb_writer = self.get_tb_writer()
        # Keep the collator factory around so it can be re-resolved per split.
        self.original_data_collator = self.data_collator
    def get_tb_writer(self):
        """Return the TensorBoardCallback registered on this trainer, or None."""
        for cb in self.callback_handler.callbacks:
            if isinstance(cb, TensorBoardCallback):
                return cb
        return None
    def log_results_to_tensorboard(self, output):
        """Decode each label/prediction pair and log it as text to TensorBoard."""
        # TODO: fix text logging when the TensorBoard writer is unavailable,
        # which happens when running evaluation only.
        if self.tb_writer is None or self.tb_writer.tb_writer is None:
            return
        for i, (label, prediction) in enumerate(
            zip(output.label_ids, output.predictions)
        ):
            try:
                decoded_label = self.tokenizer.decode(label[label != -100])
                decoded_prediction = self.tokenizer.decode(
                    prediction[prediction != -100]
                )
                total_text = f"*** label ***: {decoded_label} \n"
                total_text += f"*** prediction ***: {decoded_prediction}"
                self.tb_writer.tb_writer.add_text(
                    f"sample_{i}", total_text, self.state.global_step
                )
            except OverflowError:
                logger.error("tokenization failed for prediction %s", prediction)
    def get_train_dataloader(self) -> DataLoader:
        self.data_collator = self.original_data_collator("train")
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        self.data_collator = self.original_data_collator("eval")
        return super().get_eval_dataloader(eval_dataset)
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs,
    ) -> Dict[str, float]:
        """
        Copied from
        - https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
        - https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_seq2seq.py
        with added TensorBoard text logging.
        """
        gen_kwargs = gen_kwargs.copy()

        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
        # training args
        if (
            gen_kwargs.get("max_length") is None
            and gen_kwargs.get("max_new_tokens") is None
            and self.args.generation_max_length is not None
        ):
            gen_kwargs["max_length"] = self.args.generation_max_length
        if (
            gen_kwargs.get("num_beams") is None
            and self.args.generation_num_beams is not None
        ):
            gen_kwargs["num_beams"] = self.args.generation_num_beams
        # We don't want to drop samples in general
        self.gather_function = self.accelerator.gather
        self._gen_kwargs = gen_kwargs
        # return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(eval_dataset, dict):
            metrics = {}
            for eval_dataset_name, _eval_dataset in eval_dataset.items():
                dataset_metrics = self.evaluate(
                    eval_dataset=_eval_dataset,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
                )
                metrics.update(dataset_metrics)
            return metrics

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        # NOTE: no tpu
        # if self.is_fsdp_xla_v2_enabled:
        #     eval_dataloader = tpu_spmd_dataloader(eval_dataloader)

        start_time = time.time()

        eval_loop = (
            self.prediction_loop
            if self.args.use_legacy_prediction_loop
            else self.evaluation_loop
        )
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        # NOTE: no tpu
        # if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
        #     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        #     xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, output.metrics
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        # NOTE: text logging
        if self.args.log_generated_texts:
            self.log_results_to_tensorboard(output)

        return output.metrics
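

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not invoked by this module). Names such as
# `model`, `training_args`, `train_ds`, `eval_ds`, `tokenizer`, `build_collator`
# and `compute_metrics` are hypothetical placeholders constructed elsewhere.
# The two assumptions this trainer makes are visible above: `data_collator`
# must be a factory mapping a split name to a collator, and `training_args`
# must expose a `log_generated_texts` flag for the TensorBoard text logging.
#
#     trainer = ARTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_ds,
#         eval_dataset=eval_ds,
#         tokenizer=tokenizer,
#         data_collator=build_collator,   # callable: "train" / "eval" -> collator
#         compute_metrics=compute_metrics,
#     )
#     trainer.train()
#     trainer.evaluate(max_new_tokens=128, num_beams=4)
# ---------------------------------------------------------------------------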