# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the translation task."""
import dataclasses
import os
from typing import Optional

from absl import logging
import sacrebleu
import tensorflow as tf, tf_keras
import tensorflow_text as tftxt

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
from official.nlp.data import data_loader_factory
from official.nlp.metrics import bleu
from official.nlp.modeling import models


def _pad_tensors_to_same_length(x, y):
  """Pad x and y so that the results have the same length (second dimension)."""
  x_length = tf.shape(x)[1]
  y_length = tf.shape(y)[1]
  max_length = tf.maximum(x_length, y_length)

  x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
  y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
  return x, y


def _padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
  """Calculate cross entropy loss while ignoring padding.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]
    smoothing: Label smoothing constant, used to determine the on and off
      values
    vocab_size: int size of the vocabulary

  Returns:
    Returns the cross entropy loss and weight tensors: float32 tensors with
      shape [batch_size, max(length_logits, length_labels)]
  """
  logits, labels = _pad_tensors_to_same_length(logits, labels)

  # Calculate smoothing cross entropy
  confidence = 1.0 - smoothing
  low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
  soft_targets = tf.one_hot(
      tf.cast(labels, tf.int32),
      depth=vocab_size,
      on_value=confidence,
      off_value=low_confidence)
  xentropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=soft_targets)

  # Calculate the best (lowest) possible value of cross entropy, and
  # subtract from the cross entropy loss.
  normalizing_constant = -(
      confidence * tf.math.log(confidence) +
      tf.cast(vocab_size - 1, tf.float32) * low_confidence *
      tf.math.log(low_confidence + 1e-20))
  xentropy -= normalizing_constant

  weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
  return xentropy * weights, weights


@dataclasses.dataclass
class EncDecoder(base_config.Config):
  """Configurations for Encoder/Decoder."""
  num_layers: int = 6
  num_attention_heads: int = 8
  intermediate_size: int = 2048
  activation: str = "relu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  intermediate_dropout: float = 0.1
  use_bias: bool = False
  norm_first: bool = True
  norm_epsilon: float = 1e-6


@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A base Seq2Seq model configuration."""
  encoder: EncDecoder = dataclasses.field(default_factory=EncDecoder)
  decoder: EncDecoder = dataclasses.field(default_factory=EncDecoder)

  embedding_width: int = 512
  dropout_rate: float = 0.1

  # Decoding.
  padded_decode: bool = False
  decode_max_length: Optional[int] = None
  beam_size: int = 4
  alpha: float = 0.6

  # Training.
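  # Label smoothing: the soft target distribution puts `1 - label_smoothing`
  # probability on the gold token and spreads the remainder uniformly over the
  # rest of the vocabulary (see _padded_cross_entropy_loss above).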
  label_smoothing: float = 0.1


@dataclasses.dataclass
class TranslationConfig(cfg.TaskConfig):
  """The translation task config."""
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
  # Tokenization.
  sentencepiece_model_path: str = ""
  # Evaluation.
  print_translations: Optional[bool] = None


def write_test_record(params, model_dir):
  """Writes the test input to a tfrecord."""
  # Get raw data from tfds.
  params = params.replace(transform_and_batch=False)
  dataset = data_loader_factory.get_data_loader(params).load()
  references = []
  total_samples = 0
  output_file = os.path.join(model_dir, "eval.tf_record")
  writer = tf.io.TFRecordWriter(output_file)
  for d in dataset:
    references.append(d[params.tgt_lang].numpy().decode())
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "unique_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[total_samples])),
                params.src_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[d[params.src_lang].numpy()])),
                params.tgt_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[d[params.tgt_lang].numpy()])),
            }))
    writer.write(example.SerializeToString())
    total_samples += 1
  # Pad the record with empty dummy examples so that the total number of
  # examples is a multiple of the global batch size.
  batch_size = params.global_batch_size
  num_dummy_example = batch_size - total_samples % batch_size
  for i in range(num_dummy_example):
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "unique_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[total_samples + i])),
                params.src_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b""])),
                params.tgt_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b""])),
            }))
    writer.write(example.SerializeToString())
  writer.close()
  return references, output_file


@task_factory.register_task_cls(TranslationConfig)
class TranslationTask(base_task.Task):
  """A single-replica view of training procedure.

  Tasks provide artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the
  loss and customized metrics with reduction.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    super().__init__(params, logging_dir, name=name)
    self._sentencepiece_model_path = params.sentencepiece_model_path
    if params.sentencepiece_model_path:
      self._sp_tokenizer = tftxt.SentencepieceTokenizer(
          model=tf.io.gfile.GFile(params.sentencepiece_model_path,
                                  "rb").read(),
          add_eos=True)
      try:
        empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
      except tf.errors.InternalError:
        raise ValueError(
            "EOS token not in tokenizer vocab. "
            "Please make sure the tokenizer generates a single token for an "
            "empty string.")
      self._eos_id = empty_str_tokenized.item()
      self._vocab_size = self._sp_tokenizer.vocab_size().numpy()
    else:
      raise ValueError("Sentencepiece model path not provided.")
    if (params.validation_data.input_path or
        params.validation_data.tfds_name) and self._logging_dir:
      self._references, self._tf_record_input_path = write_test_record(
          params.validation_data, self.logging_dir)

  def build_model(self) -> tf_keras.Model:
    """Creates model architecture.

    Returns:
      A model instance.
    """
    model_cfg = self.task_config.model
    encoder_kwargs = model_cfg.encoder.as_dict()
    encoder_layer = models.TransformerEncoder(**encoder_kwargs)
    decoder_kwargs = model_cfg.decoder.as_dict()
    decoder_layer = models.TransformerDecoder(**decoder_kwargs)

    return models.Seq2SeqTransformer(
        vocab_size=self._vocab_size,
        embedding_width=model_cfg.embedding_width,
        dropout_rate=model_cfg.dropout_rate,
        padded_decode=model_cfg.padded_decode,
        decode_max_length=model_cfg.decode_max_length,
        beam_size=model_cfg.beam_size,
        alpha=model_cfg.alpha,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer,
        eos_id=self._eos_id)

  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset."""
    if params.is_training:
      dataloader_params = params
    else:
      input_path = self._tf_record_input_path
      # Read from padded tf records instead.
      dataloader_params = params.replace(
          input_path=input_path,
          tfds_name="",
          tfds_split="",
          has_unique_id=True)
    dataloader_params = dataloader_params.replace(
        sentencepiece_model_path=self._sentencepiece_model_path)
    return data_loader_factory.get_data_loader(dataloader_params).load(
        input_context)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del aux_losses
    smoothing = self.task_config.model.label_smoothing
    xentropy, weights = _padded_cross_entropy_loss(model_outputs, labels,
                                                   smoothing, self._vocab_size)
    return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

  def train_step(self,
                 inputs,
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)

    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, inputs["targets"], outputs)
    return logs

  def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
    """Computes the validation loss and decodes outputs for BLEU metrics."""
    unique_ids = inputs.pop("unique_id")
    # Validation loss.
    outputs = model(inputs, training=False)
    # Computes per-replica loss to help understand if we are overfitting.
    loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
    inputs.pop("targets")
    # Beam search to calculate metrics.
    # With "targets" removed from `inputs`, this model call runs decoding and
    # returns a dict whose "outputs" entry holds the decoded token ids; it is
    # merged into the logs and consumed in aggregate_logs below.
    model_outputs = model(inputs, training=False)
    outputs = model_outputs
    logs = {
        self.loss: loss,
        "inputs": inputs["inputs"],
        "unique_ids": unique_ids,
    }
    logs.update(outputs)
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Aggregates over logs returned from a validation step."""
    if state is None:
      state = {}

    for in_token_ids, out_token_ids, unique_ids in zip(
        step_outputs["inputs"],
        step_outputs["outputs"],
        step_outputs["unique_ids"]):
      for in_ids, out_ids, u_id in zip(
          in_token_ids.numpy(), out_token_ids.numpy(), unique_ids.numpy()):
        state[u_id] = (in_ids, out_ids)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Computes corpus-level BLEU scores from the aggregated translations."""

    def _decode(ids):
      return self._sp_tokenizer.detokenize(ids).numpy().decode()

    def _trim_and_decode(ids):
      """Trim EOS and PAD tokens from ids, and decode to return a string."""
      try:
        index = list(ids).index(self._eos_id)
        return _decode(ids[:index])
      except ValueError:  # No EOS found in sequence
        return _decode(ids)

    translations = []
    for u_id in sorted(aggregated_logs):
      # Skip the dummy examples appended by write_test_record.
      if u_id >= len(self._references):
        continue
      src = _trim_and_decode(aggregated_logs[u_id][0])
      translation = _trim_and_decode(aggregated_logs[u_id][1])
      translations.append(translation)
      if self.task_config.print_translations:
        # Decoding the in_ids to reflect what the model sees.
        logging.info("Translating:\n\tInput: %s\n\tOutput: %s\n\tReference: %s",
                     src, translation, self._references[u_id])
    sacrebleu_score = sacrebleu.corpus_bleu(
        translations, [self._references]).score
    bleu_score = bleu.bleu_on_list(self._references, translations)
    return {"sacrebleu_score": sacrebleu_score, "bleu_score": bleu_score}
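

# Example usage (illustrative sketch, not part of the library API; the paths
# below are placeholders, and the concrete data configs must be of a subclass
# registered with data_loader_factory, for example a WMT data loader config):
#
#   config = TranslationConfig(
#       sentencepiece_model_path="/path/to/sentencepiece.model",
#       train_data=...,       # a registered cfg.DataConfig subclass
#       validation_data=...,  # a registered cfg.DataConfig subclass
#   )
#   task = TranslationTask(config, logging_dir="/tmp/translation_eval")
#   model = task.build_model()
#   dataset = task.build_inputs(config.train_data)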