# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the translation task."""
import dataclasses
import os
from typing import Optional

from absl import logging
import sacrebleu
import tensorflow as tf, tf_keras
import tensorflow_text as tftxt

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
from official.nlp.data import data_loader_factory
from official.nlp.metrics import bleu
from official.nlp.modeling import models


def _pad_tensors_to_same_length(x, y):
  """Pad x and y so that the results have the same length (second dimension)."""
  x_length = tf.shape(x)[1]
  y_length = tf.shape(y)[1]
  max_length = tf.maximum(x_length, y_length)

  x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
  y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
  return x, y


def _padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
  """Calculate cross entropy loss while ignoring padding.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]
    smoothing: Label smoothing constant, used to determine the on and off values
    vocab_size: int size of the vocabulary

  Returns:
    Returns the cross entropy loss and weight tensors: float32 tensors with
    shape [batch_size, max(length_logits, length_labels)]
  """
  logits, labels = _pad_tensors_to_same_length(logits, labels)

  # Calculate smoothing cross entropy.
  confidence = 1.0 - smoothing
  low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
  soft_targets = tf.one_hot(
      tf.cast(labels, tf.int32),
      depth=vocab_size,
      on_value=confidence,
      off_value=low_confidence)
  xentropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=soft_targets)

  # Calculate the best (lowest) possible value of cross entropy, and
  # subtract from the cross entropy loss.
  normalizing_constant = -(
      confidence * tf.math.log(confidence) + tf.cast(vocab_size - 1, tf.float32)
      * low_confidence * tf.math.log(low_confidence + 1e-20))
  xentropy -= normalizing_constant

  weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
  return xentropy * weights, weights
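

# A minimal usage sketch of the loss helper above (shapes and values are
# illustrative, not part of the original module). With smoothing=0.1 and
# vocab_size=8, the true class gets probability 0.9 and the other 7 classes
# share the remaining 0.1; padding positions (label id 0) get weight 0 and
# drop out of the averaged loss:
#
#   logits = tf.random.normal([2, 5, 8])
#   labels = tf.constant([[3, 7, 1, 0, 0], [6, 2, 5, 4, 1]])
#   xent, weights = _padded_cross_entropy_loss(logits, labels, 0.1, 8)
#   loss = tf.reduce_sum(xent) / tf.reduce_sum(weights)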


@dataclasses.dataclass
class EncDecoder(base_config.Config):
  """Configurations for Encoder/Decoder."""
  num_layers: int = 6
  num_attention_heads: int = 8
  intermediate_size: int = 2048
  activation: str = "relu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  intermediate_dropout: float = 0.1
  use_bias: bool = False
  norm_first: bool = True
  norm_epsilon: float = 1e-6


@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A base Seq2Seq model configuration."""
  encoder: EncDecoder = dataclasses.field(default_factory=EncDecoder)
  decoder: EncDecoder = dataclasses.field(default_factory=EncDecoder)

  embedding_width: int = 512
  dropout_rate: float = 0.1

  # Decoding.
  padded_decode: bool = False
  decode_max_length: Optional[int] = None
  beam_size: int = 4
  alpha: float = 0.6

  # Training.
  label_smoothing: float = 0.1


@dataclasses.dataclass
class TranslationConfig(cfg.TaskConfig):
  """The translation task config."""
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
  # Tokenization.
  sentencepiece_model_path: str = ""
  # Evaluation.
  print_translations: Optional[bool] = None
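

# A hedged construction sketch (the model path is a placeholder; in practice
# these fields are usually populated through experiment configs/flags rather
# than built directly):
#
#   config = TranslationConfig(
#       model=ModelConfig(encoder=EncDecoder(num_layers=6), beam_size=4),
#       sentencepiece_model_path="<path/to/sentencepiece.model>")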


def write_test_record(params, model_dir):
  """Writes the test input to a tfrecord."""
  # Get raw data from tfds.
  params = params.replace(transform_and_batch=False)
  dataset = data_loader_factory.get_data_loader(params).load()
  references = []
  total_samples = 0
  output_file = os.path.join(model_dir, "eval.tf_record")
  writer = tf.io.TFRecordWriter(output_file)
  for d in dataset:
    references.append(d[params.tgt_lang].numpy().decode())
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "unique_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[total_samples])),
                params.src_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[d[params.src_lang].numpy()])),
                params.tgt_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[d[params.tgt_lang].numpy()])),
            }))
    writer.write(example.SerializeToString())
    total_samples += 1
  batch_size = params.global_batch_size
  num_dummy_example = batch_size - total_samples % batch_size
  for i in range(num_dummy_example):
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "unique_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[total_samples + i])),
                params.src_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b""])),
                params.tgt_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b""])),
            }))
    writer.write(example.SerializeToString())
  writer.close()
  return references, output_file
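

# Note on the dummy padding above: the eval record is padded to a multiple of
# `global_batch_size` so the last batch is full. For example, 3003 real
# examples with a global batch size of 64 leave 3003 % 64 = 59 in the final
# batch, so 5 empty examples are appended (3008 records total). If the sample
# count is already a multiple of the batch size, a whole extra batch of empty
# examples is appended.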


@task_factory.register_task_cls(TranslationConfig)
class TranslationTask(base_task.Task):
  """A single-replica view of the training procedure.

  Tasks provide artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the loss
  and customized metrics with reduction.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    super().__init__(params, logging_dir, name=name)
    self._sentencepiece_model_path = params.sentencepiece_model_path
    if params.sentencepiece_model_path:
      self._sp_tokenizer = tftxt.SentencepieceTokenizer(
          model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
          add_eos=True)
      try:
        empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
      except tf.errors.InternalError:
        raise ValueError(
            "EOS token not in tokenizer vocab. "
            "Please make sure the tokenizer generates a single token for an "
            "empty string.")
      self._eos_id = empty_str_tokenized.item()
      self._vocab_size = self._sp_tokenizer.vocab_size().numpy()
    else:
      raise ValueError("Sentencepiece model path not provided.")

    if (params.validation_data.input_path or
        params.validation_data.tfds_name) and self._logging_dir:
      self._references, self._tf_record_input_path = write_test_record(
          params.validation_data, self.logging_dir)
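
  # A minimal sketch of the EOS lookup in __init__ above (assumes a
  # SentencePiece model for which tokenizing the empty string yields exactly
  # one token, the EOS id; `serialized_model` is a placeholder for the model
  # bytes):
  #
  #   sp = tftxt.SentencepieceTokenizer(model=serialized_model, add_eos=True)
  #   eos_id = sp.tokenize("").numpy().item()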

  def build_model(self) -> tf_keras.Model:
    """Creates model architecture.

    Returns:
      A model instance.
    """
    model_cfg = self.task_config.model
    encoder_kwargs = model_cfg.encoder.as_dict()
    encoder_layer = models.TransformerEncoder(**encoder_kwargs)
    decoder_kwargs = model_cfg.decoder.as_dict()
    decoder_layer = models.TransformerDecoder(**decoder_kwargs)

    return models.Seq2SeqTransformer(
        vocab_size=self._vocab_size,
        embedding_width=model_cfg.embedding_width,
        dropout_rate=model_cfg.dropout_rate,
        padded_decode=model_cfg.padded_decode,
        decode_max_length=model_cfg.decode_max_length,
        beam_size=model_cfg.beam_size,
        alpha=model_cfg.alpha,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer,
        eos_id=self._eos_id)
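
  # A hedged usage sketch (not from the original module): once the task is
  # constructed with a valid config, build_model() returns a Seq2SeqTransformer
  # whose encoder/decoder follow the EncDecoder configs; `src_ids`/`tgt_ids`
  # are placeholder int32 token-id tensors:
  #
  #   model = task.build_model()
  #   logits = model({"inputs": src_ids, "targets": tgt_ids}, training=True)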

  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset."""
    if params.is_training:
      dataloader_params = params
    else:
      input_path = self._tf_record_input_path
      # Read from padded tf records instead.
      dataloader_params = params.replace(
          input_path=input_path,
          tfds_name="",
          tfds_split="",
          has_unique_id=True)
    dataloader_params = dataloader_params.replace(
        sentencepiece_model_path=self._sentencepiece_model_path)
    return data_loader_factory.get_data_loader(dataloader_params).load(
        input_context)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del aux_losses
    smoothing = self.task_config.model.label_smoothing
    xentropy, weights = _padded_cross_entropy_loss(model_outputs, labels,
                                                   smoothing, self._vocab_size)
    return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

  def train_step(self,
                 inputs,
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, inputs["targets"], outputs)
    return logs
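
  # Loss-scaling note for the step above (replica count is illustrative): with
  # 8 replicas each one contributes loss / 8, so the cross-replica gradient sum
  # matches the gradient of the mean loss; with a LossScaleOptimizer the loss
  # is additionally multiplied by the loss scale before tape.gradient, and the
  # gradients are divided back by get_unscaled_gradients before
  # apply_gradients.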

  def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
    unique_ids = inputs.pop("unique_id")
    # Validation loss.
    outputs = model(inputs, training=False)
    # Computes per-replica loss to help understand if we are overfitting.
    loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
    inputs.pop("targets")

    # Beam search to calculate metrics.
    model_outputs = model(inputs, training=False)
    outputs = model_outputs
    logs = {
        self.loss: loss,
        "inputs": inputs["inputs"],
        "unique_ids": unique_ids,
    }
    logs.update(outputs)
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Aggregates over logs returned from a validation step."""
    if state is None:
      state = {}

    for in_token_ids, out_token_ids, unique_ids in zip(
        step_outputs["inputs"],
        step_outputs["outputs"],
        step_outputs["unique_ids"]):
      for in_ids, out_ids, u_id in zip(
          in_token_ids.numpy(), out_token_ids.numpy(), unique_ids.numpy()):
        state[u_id] = (in_ids, out_ids)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):

    def _decode(ids):
      return self._sp_tokenizer.detokenize(ids).numpy().decode()

    def _trim_and_decode(ids):
      """Trim EOS and PAD tokens from ids, and decode to return a string."""
      try:
        index = list(ids).index(self._eos_id)
        return _decode(ids[:index])
      except ValueError:  # No EOS found in sequence.
        return _decode(ids)

    translations = []
    for u_id in sorted(aggregated_logs):
      if u_id >= len(self._references):
        continue
      src = _trim_and_decode(aggregated_logs[u_id][0])
      translation = _trim_and_decode(aggregated_logs[u_id][1])
      translations.append(translation)
      if self.task_config.print_translations:
        # Decoding the in_ids to reflect what the model sees.
        logging.info("Translating:\n\tInput: %s\n\tOutput: %s\n\tReference: %s",
                     src, translation, self._references[u_id])
    sacrebleu_score = sacrebleu.corpus_bleu(
        translations, [self._references]).score
    bleu_score = bleu.bleu_on_list(self._references, translations)
    return {"sacrebleu_score": sacrebleu_score,
            "bleu_score": bleu_score}