import math
import time
from typing import Dict, List, Optional
from torch.utils.data import DataLoader, Dataset
from transformers import Seq2SeqTrainer
from transformers.integrations import TensorBoardCallback
from transformers.trainer_utils import speed_metrics
from transformers.utils import logging
skip_first_batches = None
IS_SAGEMAKER_MP_POST_1_10 = False
GENERATION_RESULTS = "generated"
logger = logging.get_logger(__name__)
class ARTrainer(Seq2SeqTrainer):
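    """Seq2SeqTrainer variant that builds split-specific data collators from a
    collator factory and logs decoded labels/predictions to TensorBoard after
    evaluation (when ``args.log_generated_texts`` is set)."""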
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, preprocess_logits_for_metrics=None)
        self.tb_writer = self.get_tb_writer()
        # Keep the original collator: it is called with a split name ("train"/"eval")
        # in get_train_dataloader / get_eval_dataloader to build the actual collator.
        self.original_data_collator = self.data_collator

    def get_tb_writer(self):
        # Return the registered TensorBoardCallback, or None if there is none.
        for cb in self.callback_handler.callbacks:
            if isinstance(cb, TensorBoardCallback):
                return cb
        return None
    def log_results_to_tensorboard(self, output):
        # TODO: handle the eval-only case more cleanly, where no TensorBoard
        # writer has been initialized yet.
        if self.tb_writer is None or self.tb_writer.tb_writer is None:
            return
        for i, (label, prediction) in enumerate(
            zip(output.label_ids, output.predictions)
        ):
            try:
                total_text = ""
                # Drop the -100 ignore indices before decoding.
                decoded_label = self.tokenizer.decode(label[label != -100])
                decoded_prediction = self.tokenizer.decode(
                    prediction[prediction != -100]
                )
                total_text += f"*** label ***: {decoded_label} \n"
                total_text += f"*** prediction ***: {decoded_prediction}"
                self.tb_writer.tb_writer.add_text(
                    f"sample_{i}", total_text, self.state.global_step
                )
            except OverflowError:
                logger.error("Failed to decode prediction: %s", prediction)
    def get_train_dataloader(self) -> DataLoader:
        self.data_collator = self.original_data_collator("train")
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        self.data_collator = self.original_data_collator("eval")
        return super().get_eval_dataloader(eval_dataset)
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs,
    ) -> Dict[str, float]:
        """
        Copied from
        - https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
        - https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_seq2seq.py
        with added tensorboard text logging.
        """
        gen_kwargs = gen_kwargs.copy()

        # Use legacy argument setting if a) the option is not explicitly passed; and
        # b) the argument is set in the training args.
        if (
            gen_kwargs.get("max_length") is None
            and gen_kwargs.get("max_new_tokens") is None
            and self.args.generation_max_length is not None
        ):
            gen_kwargs["max_length"] = self.args.generation_max_length
        if (
            gen_kwargs.get("num_beams") is None
            and self.args.generation_num_beams is not None
        ):
            gen_kwargs["num_beams"] = self.args.generation_num_beams

        # We don't want to drop samples in general
        self.gather_function = self.accelerator.gather
        self._gen_kwargs = gen_kwargs

        # return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(eval_dataset, dict):
            metrics = {}
            for eval_dataset_name, _eval_dataset in eval_dataset.items():
                dataset_metrics = self.evaluate(
                    eval_dataset=_eval_dataset,
                    ignore_keys=ignore_keys,
                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
                )
                metrics.update(dataset_metrics)
            return metrics
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        # NOTE: no tpu
        # if self.is_fsdp_xla_v2_enabled:
        #     eval_dataloader = tpu_spmd_dataloader(eval_dataloader)

        start_time = time.time()

        eval_loop = (
            self.prediction_loop
            if self.args.use_legacy_prediction_loop
            else self.evaluation_loop
        )
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        # NOTE: no tpu
        # if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
        #     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        #     xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, output.metrics
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        # NOTE: text logging
        if self.args.log_generated_texts:
            self.log_results_to_tensorboard(output)

        return output.metrics
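
# Example usage (a minimal sketch, not part of the original module). It assumes that
# `data_collator` is a factory taking a split name ("train"/"eval") and returning the
# actual collator, which is how get_train_dataloader / get_eval_dataloader call it.
# `model`, `tokenizer`, the datasets, `training_args`, and its custom
# `log_generated_texts` flag are placeholders assumed to be defined elsewhere.
#
#     from transformers import DataCollatorForSeq2Seq
#
#     def collator_factory(split: str):
#         # e.g. configure padding (or other behaviour) differently per split
#         return DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest")
#
#     trainer = ARTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         tokenizer=tokenizer,
#         data_collator=collator_factory,
#         compute_metrics=compute_metrics,
#     )
#     trainer.train()
#     metrics = trainer.evaluate()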