# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the translation task."""
import dataclasses
import os
from typing import Optional

from absl import logging
import sacrebleu
import tensorflow as tf, tf_keras
import tensorflow_text as tftxt

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
from official.nlp.data import data_loader_factory
from official.nlp.metrics import bleu
from official.nlp.modeling import models


def _pad_tensors_to_same_length(x, y):
  """Pad x and y so that the results have the same length (second dimension)."""
  x_length = tf.shape(x)[1]
  y_length = tf.shape(y)[1]

  max_length = tf.maximum(x_length, y_length)

  x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
  y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
  return x, y


def _padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
  """Calculate cross entropy loss while ignoring padding.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]
    smoothing: Label smoothing constant, used to determine the on and off values
    vocab_size: int size of the vocabulary

  Returns:
    Returns the cross entropy loss and weight tensors: float32 tensors with
      shape [batch_size, max(length_logits, length_labels)]
  """
  logits, labels = _pad_tensors_to_same_length(logits, labels)

  # Calculate smoothing cross entropy
  confidence = 1.0 - smoothing
  low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
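  # Worked example (illustrative numbers only): with smoothing=0.1 and a
  # 32,000-token vocabulary, the true token gets confidence 0.9 and every
  # other token gets 0.1 / 31,999, i.e. about 3.1e-6 of the probability mass.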
  soft_targets = tf.one_hot(
      tf.cast(labels, tf.int32),
      depth=vocab_size,
      on_value=confidence,
      off_value=low_confidence)
  xentropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=soft_targets)

  # Calculate the best (lowest) possible value of cross entropy, and
  # subtract from the cross entropy loss.
  normalizing_constant = -(
      confidence * tf.math.log(confidence) + tf.cast(vocab_size - 1, tf.float32)
      * low_confidence * tf.math.log(low_confidence + 1e-20))
  xentropy -= normalizing_constant

  weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
  return xentropy * weights, weights


@dataclasses.dataclass
class EncDecoder(base_config.Config):
  """Configurations for Encoder/Decoder."""
  num_layers: int = 6
  num_attention_heads: int = 8
  intermediate_size: int = 2048
  activation: str = "relu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  intermediate_dropout: float = 0.1
  use_bias: bool = False
  norm_first: bool = True
  norm_epsilon: float = 1e-6
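  # Defaults match the Transformer-base layer settings (6 layers, 8 heads,
  # 2048-wide feed-forward). Illustrative override via plain dataclass kwargs:
  # EncDecoder(num_layers=12, intermediate_size=4096).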


@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A base Seq2Seq model configuration."""
  encoder: EncDecoder = dataclasses.field(default_factory=EncDecoder)
  decoder: EncDecoder = dataclasses.field(default_factory=EncDecoder)

  embedding_width: int = 512
  dropout_rate: float = 0.1

  # Decoding.
  padded_decode: bool = False
  decode_max_length: Optional[int] = None
  beam_size: int = 4
  alpha: float = 0.6

  # Training.
  label_smoothing: float = 0.1


@dataclasses.dataclass
class TranslationConfig(cfg.TaskConfig):
  """The translation task config."""
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
  # Tokenization
  sentencepiece_model_path: str = ""
  # Evaluation.
  print_translations: Optional[bool] = None
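  # Illustrative construction (the model path below is hypothetical):
  #   TranslationConfig(
  #       model=ModelConfig(beam_size=1),
  #       sentencepiece_model_path="/path/to/spm.model")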


def write_test_record(params, model_dir):
  """Writes the test input to a tfrecord."""
  # Get raw data from tfds.
  params = params.replace(transform_and_batch=False)
  dataset = data_loader_factory.get_data_loader(params).load()
  references = []
  total_samples = 0
  output_file = os.path.join(model_dir, "eval.tf_record")
  writer = tf.io.TFRecordWriter(output_file)
  for d in dataset:
    references.append(d[params.tgt_lang].numpy().decode())
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "unique_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[total_samples])),
                params.src_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[d[params.src_lang].numpy()])),
                params.tgt_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[d[params.tgt_lang].numpy()])),
            }))
    writer.write(example.SerializeToString())
    total_samples += 1
  batch_size = params.global_batch_size
  num_dummy_example = batch_size - total_samples % batch_size
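  # Illustrative arithmetic: with 1,000 eval examples and global_batch_size=64,
  # 1000 % 64 == 40, so 24 empty dummy records are appended, giving 1,024
  # records (16 full batches). If the count is already a multiple of the batch
  # size, a full dummy batch is still added.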
  for i in range(num_dummy_example):
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "unique_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[total_samples + i])),
                params.src_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b""])),
                params.tgt_lang: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b""])),
            }))
    writer.write(example.SerializeToString())
  writer.close()
  return references, output_file


@task_factory.register_task_cls(TranslationConfig)
class TranslationTask(base_task.Task):
  """A single-replica view of training procedure.

  Tasks provide artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the loss
  and customized metrics with reduction.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    super().__init__(params, logging_dir, name=name)
    self._sentencepiece_model_path = params.sentencepiece_model_path
    if params.sentencepiece_model_path:
      self._sp_tokenizer = tftxt.SentencepieceTokenizer(
          model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
          add_eos=True)
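      # With add_eos=True, tokenizing the empty string should yield exactly
      # one id, which is taken as the EOS id below.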
      try:
        empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
      except tf.errors.InternalError:
        raise ValueError(
            "EOS token not in tokenizer vocab."
            "Please make sure the tokenizer generates a single token for an "
            "empty string.")
      self._eos_id = empty_str_tokenized.item()
      self._vocab_size = self._sp_tokenizer.vocab_size().numpy()
    else:
      raise ValueError("Setencepiece model path not provided.")
    if (params.validation_data.input_path or
        params.validation_data.tfds_name) and self._logging_dir:
      self._references, self._tf_record_input_path = write_test_record(
          params.validation_data, self.logging_dir)

  def build_model(self) -> tf_keras.Model:
    """Creates model architecture.

    Returns:
      A model instance.
    """
    model_cfg = self.task_config.model
    encoder_kwargs = model_cfg.encoder.as_dict()
    encoder_layer = models.TransformerEncoder(**encoder_kwargs)
    decoder_kwargs = model_cfg.decoder.as_dict()
    decoder_layer = models.TransformerDecoder(**decoder_kwargs)

    return models.Seq2SeqTransformer(
        vocab_size=self._vocab_size,
        embedding_width=model_cfg.embedding_width,
        dropout_rate=model_cfg.dropout_rate,
        padded_decode=model_cfg.padded_decode,
        decode_max_length=model_cfg.decode_max_length,
        beam_size=model_cfg.beam_size,
        alpha=model_cfg.alpha,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer,
        eos_id=self._eos_id)

  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset."""
    if params.is_training:
      dataloader_params = params
    else:
      input_path = self._tf_record_input_path
      # Read from padded tf records instead.
      dataloader_params = params.replace(
          input_path=input_path,
          tfds_name="",
          tfds_split="",
          has_unique_id=True)
    dataloader_params = dataloader_params.replace(
        sentencepiece_model_path=self._sentencepiece_model_path)
    return data_loader_factory.get_data_loader(dataloader_params).load(
        input_context)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del aux_losses

    smoothing = self.task_config.model.label_smoothing
    xentropy, weights = _padded_cross_entropy_loss(model_outputs, labels,
                                                   smoothing, self._vocab_size)
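    # Average over non-padding target tokens (weights are zero on padding).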
    return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

  def train_step(self,
                 inputs,
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)

    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, inputs["targets"], outputs)
    return logs

  def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
    unique_ids = inputs.pop("unique_id")
    # Validation loss
    outputs = model(inputs, training=False)
    # Computes per-replica loss to help understand if we are overfitting.
    loss = self.build_losses(labels=inputs["targets"], model_outputs=outputs)
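    # The pass above keeps "targets" in the inputs (teacher forcing); dropping
    # "targets" below switches the model to autoregressive decoding.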
    inputs.pop("targets")
    # Beam search to calculate metrics.
    outputs = model(inputs, training=False)
    logs = {
        self.loss: loss,
        "inputs": inputs["inputs"],
        "unique_ids": unique_ids,
    }
    logs.update(outputs)
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Aggregates over logs returned from a validation step."""
    if state is None:
      state = {}
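    # `state` maps each example's unique id to its (input ids, decoded ids)
    # pair so outputs can be re-ordered by id when metrics are reduced.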

    for in_token_ids, out_token_ids, unique_ids in zip(
        step_outputs["inputs"],
        step_outputs["outputs"],
        step_outputs["unique_ids"]):
      for in_ids, out_ids, u_id in zip(
          in_token_ids.numpy(), out_token_ids.numpy(), unique_ids.numpy()):
        state[u_id] = (in_ids, out_ids)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):

    def _decode(ids):
      return self._sp_tokenizer.detokenize(ids).numpy().decode()

    def _trim_and_decode(ids):
      """Trim EOS and PAD tokens from ids, and decode to return a string."""
      try:
        index = list(ids).index(self._eos_id)
        return _decode(ids[:index])
      except ValueError:  # No EOS found in sequence
        return _decode(ids)
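    # Illustrative: with an EOS id of 2, ids [15, 7, 2, 0, 0] are trimmed to
    # [15, 7] before detokenization.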

    translations = []
    for u_id in sorted(aggregated_logs):
      if u_id >= len(self._references):
        continue
      src = _trim_and_decode(aggregated_logs[u_id][0])
      translation = _trim_and_decode(aggregated_logs[u_id][1])
      translations.append(translation)
      if self.task_config.print_translations:
        # Decoding the in_ids to reflect what the model sees.
        logging.info("Translating:\n\tInput: %s\n\tOutput: %s\n\tReference: %s",
                     src, translation, self._references[u_id])
    sacrebleu_score = sacrebleu.corpus_bleu(
        translations, [self._references]).score
    bleu_score = bleu.bleu_on_list(self._references, translations)
    return {"sacrebleu_score": sacrebleu_score,
            "bleu_score": bleu_score}