# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequence to sequence model."""

from typing import Any, Callable, Dict, Tuple

from absl import logging
from flax import linen as nn
from flax.training import common_utils
import gin
import jax
import jax.numpy as jnp
import metrics_summary
from transformer import decoder_stack
from transformer import metric_utils
from transformer import text_dataset
import numpy as np
import seqio


Array = jnp.ndarray
MetricsSummary = metrics_summary.MetricsSummary


# TODO(mrabe): Remove this function and find a better way to turn text metrics
# into text on tensorboard.
def process_summaries(vocab: seqio.Vocabulary,
                      met_summary: MetricsSummary,
                      mode: str) -> MetricsSummary:
  """Compute some additional summaries, and convert tokens to text.

  Args:
    vocab: The vocabulary to detokenize generated text.
    met_summary: The summary object to process.
    mode: The mode of the summary (e.g. "test", "train")

  Returns:
    The modified summary dictionary.
  """

  mdict = met_summary.current_metric_dict()

  # Calculate perplexity from the average nats_per_token over all replicas.
  # This has to be done here, because the perplexities themselves can't be
  # averaged in the usual way.
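  # (Note that exp(mean(nats)) is not the same as mean(exp(nats)), so we
  # exponentiate only after averaging.)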
  if "nats_per_token" in mdict:
    nats_per_token = mdict["nats_per_token"].to_value()
    met_summary.add({"perplexity": np.exp(nats_per_token)})

  if mode == "generate" and "gen_tokens" in mdict:
    # Convert output tokens to example output text.
    # Write text to both the summary, and pretty-print to the log file.
    gen_toks = mdict["gen_tokens"].to_value()
    if np.ndim(gen_toks) != 2:
      raise ValueError("Unsupported shape for gen_tokens: %s" % gen_toks.shape)

    ntoks = gen_toks.shape[-1]
    gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
    logging.info("Generated text = %s", gen_text)
    met_summary.add_text({"gen_text": gen_text})
    del mdict["gen_tokens"]   # Otherwise it will turn into a histogram.

  return met_summary


@gin.configurable
def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
    [MetricsSummary, str], MetricsSummary]:
  """Return a function that processes summaries with the given vocabulary."""
  # For use with training_loop.process_summaries_function
  def process_fn(met_summary: MetricsSummary, mode: str):
    return process_summaries(vocab, met_summary, mode)
  return process_fn
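

# A minimal usage sketch (names like `my_vocab` are illustrative, not part of
# this module; the training loop is expected to invoke the returned function on
# each metrics summary):
#
#   process_fn = process_summaries_function(vocab=my_vocab)
#   summary = process_fn(summary, "generate")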


@gin.configurable
class DecoderOnlyLanguageModel(nn.Module):
  """Decoder only language modeling."""

  mode: str
  task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
  decoder_factory: Callable[[], Any] = gin.REQUIRED

  sample_method: str = "sample"   # Can be {"sample", "greedy"}
  output_token_losses: bool = False

  def get_fake_input(self):
    """Returns a fake input for initialization of the appropriate shape."""
    b = self.task_config.batch_size
    fake_input_dict = {
        "targets": jnp.ones([b, self.task_config.sequence_length],
                            dtype=jnp.int32),
        "start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
        "epoch": jnp.ones([b], dtype=jnp.int32),
    }
    # Only add the loss mask to the dummy input when it will actually be used,
    # as adding it unconditionally can cause a slowdown during evaluation and
    # perhaps inference.
    if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
      fake_input_dict["loss_mask"] = jnp.ones(
          [b, self.task_config.sequence_length], dtype=jnp.bool_)
    return fake_input_dict
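
  # A minimal initialization sketch (the rng handling below is illustrative and
  # may not match the actual training loop):
  #
  #   model = DecoderOnlyLanguageModel(mode="train", ...)
  #   params = model.init({"params": init_rng}, model.get_fake_input())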

  def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
    """Summary operation to use for recorded metrics."""
    metric_ops = {
        "loss": "mean",
        "nats_per_token": "mean",
        "bits_per_token": "mean",
        "bits_per_char": "mean",
        "accuracy": "mean",
        "num_tokens": "mean",
        "num_chars_per_device": "mean",
        "num_chars_per_batch": "mean",
        "nonzero_tokens": "mean",
        "num_tokens_per_device": "mean",
        "num_tokens_per_batch": "mean",
        "epoch": "mean",
    }
    if aggregate_over == "steps":
      return metric_ops
    elif aggregate_over == "devices":
      # Ensure that statistics that refer to the total batch size stay constant
      # as TPU topologies change. For those we have to sum over devices, but
      # compute the mean over steps.
      metric_ops.update({
          "num_tokens_per_batch": "sum",
          "num_chars_per_batch": "sum",
          "loss": "sum"})
      return metric_ops
    else:
      raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)

  def setup(self):
    self.decoder = self.decoder_factory(mode=self.mode,
                                        task_config=self.task_config)  # pytype: disable=wrong-keyword-args  # trace-all-classes

  def __call__(self, inputs: ...):
    task_config = self.task_config

    input_tokens = inputs["targets"]                  # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]   # [b]
    epochs = inputs["epoch"]                          # [b]
    if "loss_mask" in inputs:
      loss_mask = inputs["loss_mask"]                 # [b, seq_len]
    else:
      loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)  # broadcasts over [b, seq_len]

    input_tokens = jnp.asarray(input_tokens)
    assert input_tokens.ndim == 2
    assert input_tokens.shape[0] == task_config.batch_size
    assert input_tokens.shape[1] == task_config.sequence_length
    assert start_of_sequence.shape[0] == task_config.batch_size

    # Sanity check to avoid out-of-bounds on token lookup.
    input_tokens = input_tokens % task_config.vocab_size

    logging.info("langmodel: Compiling model for mode %s", self.mode)
    logging.info("langmodel: input_tokens = %r", input_tokens)
    logging.info("langmodel: start_of_sequece = %r", start_of_sequence)
    logging.info("langmodel: epochs = %r", epochs)

    # The target outputs are the next character in each sequence.
    # Shift tokens left and pad with a zero at the end.
    # TODO(delesley): We don't predict the first token of each sequence.
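    # E.g. a row of input tokens [5, 7, 9] yields target tokens [7, 9, 0].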
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
    logging.info("langmodel: target_tokens = %r", target_tokens)

    # Invoke the decoder stack.
    # The decoder will return pre-softmax logits for the predicted targets.
    (logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
                                          target_tokens=target_tokens,
                                          start_of_sequence=start_of_sequence)

    # Softmax cross-entropy loss on target tokens.
    logits = nn.log_softmax(logits, axis=-1)   # (b, seq_len, vocab_size)
    logging.info("langmodel: logits = %r", logits)
    soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
    logging.info("langmodel: soft_targets = %r", soft_targets)

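    # Since soft_targets is one-hot, the sum below picks out -log p(target
    # token), i.e. the per-token negative log-likelihood in nats.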
    losses = -jnp.sum(soft_targets * logits, axis=-1)  # (b, seq_len)
    logging.info("langmodel: losses = %r", losses)

    # Don't predict null tokens which are past the end-of-sequence.
    # Also don't predict the 0 at the end of the sequence.
    # TODO(delesley): Predict the final end-of-sequence marker.
    loss_mask = jnp.logical_and(
        loss_mask,
        input_tokens > 0)
    loss_mask = jnp.logical_and(
        loss_mask,
        target_tokens > 0)
    logging.info("langmodel: loss_mask = %r", loss_mask)

    losses = jnp.where(loss_mask, losses, 0.0)  # (batch_size, seq_len)
    loss = jnp.sum(losses)  # total loss on device

    token_count = jnp.sum(loss_mask)  # tokens on device
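    # Add a small epsilon so the divisions below are safe when the mask is empty.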
    token_count_nz = token_count + 1.0e-6
    loss_per_token = loss / token_count_nz
    bits_per_token = loss_per_token * 1.442695  # log(e)/log(2)
    accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
                                                 loss_mask)
    accuracy = accuracy / token_count_nz  # Fraction of correct predictions.
    epoch = jnp.mean(epochs)

    if self.mode == "generate" and self.decoder.supports_generate():
      # Generate example text.
      logging.info("lang_model: text inference.")
      gen_tokens = self.generate(inputs, task_config.sequence_length)

      # Return generated text, along with visualizations and histograms.
      metrics = {"gen_tokens": gen_tokens, **d_metrics}
      return (loss, metrics)

    # Just return metrics related to the loss.
    metrics = {
        "loss": loss,   # will be summed over devices
        "nats_per_token": (loss_per_token, token_count),
        "bits_per_token": (bits_per_token, token_count),
        "accuracy": (accuracy, token_count),
        "num_tokens_per_device": token_count,
        "num_tokens_per_batch": token_count,  # will be summed over devices
        "epoch": epoch,
    }

    # Compute bits per character if we have the number of characters.
    if "num_chars" in inputs:
      num_chars = jnp.sum(inputs["num_chars"])
      bits_per_char = loss / (num_chars + 1e-6) * 1.442695
      metrics["num_chars_per_device"] = num_chars
      metrics["num_chars_per_batch"] = num_chars  # will be summed over devices
      metrics["bits_per_char"] = (bits_per_char, num_chars)

    # Provided to make sure that the data pipeline and the model agree
    # on the number of tokens with a loss.
    if "nonzero_tokens" in inputs:
      nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
      metrics["nonzero_tokens"] = nonzero_tokens

    if self.output_token_losses:
      metrics["token_losses"] = losses

    return (loss, metrics)

  def generate(self, inputs: ..., sequence_length: int) -> Array:
    """Generate an output sequence.

    Args:
      inputs: the same input dictionary that is passed to __call__.
      sequence_length: the length of the sequence to generate.

    Returns:
      An array of generated tokens of shape (batch_size, sequence_length).
    """
    # TODO(delesley): Add support for passing the prefix as an argument.
    # TODO(delesley): Add support for temperature, gumbel softmax, beam search.

    batch_size = self.task_config.batch_size
    input_tokens = inputs["targets"]                  # [b,seq_len]
    start_of_sequence = inputs["start_of_sequence"]   # [b]

    # Initialize decoder.
    dstate = self.decoder.init_decoder_state(sequence_length,
                                             start_of_sequence)

    # TODO(delesley): Handle start-of-sequence in a better way.
    # There is no special token for start of sequence, so we grab the first
    # one from the ground-truth input data.
    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    sample_method = self.sample_method
    sample_prng = self.make_rng("sample")

    # Greedy autoregressive decoder function.
    def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
      prng = jax.random.fold_in(sample_prng, i)
      (dstate, input_token) = scan_state
      del i
      (logits, dstate, _) = self.decoder(input_tokens=input_token,
                                         target_tokens=None,
                                         start_of_sequence=no_start_of_seq,
                                         decoder_state=dstate)
      if sample_method == "sample":
        logging.info("Using categorical sampling.")
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "greedy":
        logging.info("Using greedy sampling.")
        output_token = jnp.argmax(logits, axis=-1)
      else:
        raise ValueError(f"Invalid sampling method: {sample_method}")
      logging.info("generate_loop_fn: output_token = %r", output_token)
      return ((dstate, output_token), output_token)

    # Scan over the sequence length.
    iterations = jnp.arange(sequence_length)
    initial_scan_state = (dstate, first_token)
    (_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
    logging.info("generate: output_tokens = %r", output_tokens)

    # Output_tokens has shape (sequence_length, batch_size, 1)
    assert output_tokens.shape == (sequence_length, batch_size, 1)
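    # Drop the trailing singleton axis, then move the batch dimension to the front.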
    output_tokens = jnp.reshape(
        output_tokens, (sequence_length, self.task_config.batch_size))
    output_tokens = output_tokens.transpose([1, 0])
    return output_tokens