# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for Decoding Strategies (beam_search, top_k, top_p and greedy)."""
import abc
from typing import Any, Callable, Dict, Optional, Tuple
import tensorflow as tf, tf_keras
from tensorflow.python.framework import dtypes
from official.modeling import tf_utils
Output = Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
InitialState = Tuple[Dict[str, Any], Dict[str, Any]]
class StateKeys:
"""Keys to dictionary storing the state of Decoding loop."""
# Variable storing the loop index.
CUR_INDEX = "CUR_INDEX"
# Top sequences that are alive for each batch item. Alive sequences are ones
# that have not generated an EOS token. Sequences that reach EOS are marked as
# finished and moved to the FINISHED_SEQ tensor.
# Has shape [batch_size, beam_size, CUR_INDEX + 1] for SequenceBeamSearch and
# [batch_size, CUR_INDEX + 1] otherwise.
ALIVE_SEQ = "ALIVE_SEQ"
# Log probabilities of each alive sequence. Shape [batch_size, beam_size]
ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
# Dictionary of cached values for each alive sequence. The cache stores
# the encoder output, attention bias, and the decoder attention output from
# the previous iteration.
ALIVE_CACHE = "ALIVE_CACHE"
  # The initial model state/cache after the model has processed the initial
  # token. Only filled if extra_cache_output is true.
INITIAL_OUTPUT_CACHE = "INITIAL_OUTPUT_CACHE"
# Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s.
FINISHED_SEQ = "FINISHED_SEQ"
# Scores for each finished sequence. Score = log probability / length norm
# Shape [batch_size, beam_size]
FINISHED_SCORES = "FINISHED_SCORES"
# Flags indicating which sequences in the finished sequences are finished.
# At the beginning, all of the sequences in FINISHED_SEQ are filler values.
# True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
FINISHED_FLAGS = "FINISHED_FLAGS"
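

# For orientation: a beam-search loop state assembled from the keys above
# has roughly these shapes at step i (see the per-key comments for details):
#   {CUR_INDEX: int32 scalar,
#    ALIVE_SEQ: [batch, beam, i + 1], ALIVE_LOG_PROBS: [batch, beam],
#    ALIVE_CACHE: nested dict of tensors,
#    FINISHED_SEQ: [batch, beam, i + 1], FINISHED_SCORES: [batch, beam],
#    FINISHED_FLAGS: [batch, beam]}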
def log_prob_from_logits(logits):
return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)
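# Illustrative use (not part of the module): this is a numerically stable
# log-softmax, so the result exponentiates to a distribution over the last
# axis:
#   log_probs = log_prob_from_logits(tf.constant([[1.0, 2.0, 3.0]]))
#   tf.reduce_sum(tf.exp(log_probs), axis=-1)  # ~= [1.0]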
def shape_list(tensor):
"""Return a list of the tensor's shape, and ensure no None values in list."""
return tf_utils.get_shape_list(tensor)
def get_shape_keep_last_dim(tensor):
  """Returns a TensorShape with only the last dimension kept static."""
shape_list_obj = shape_list(tensor)
for i in range(len(shape_list_obj) - 1):
shape_list_obj[i] = None
if isinstance(shape_list_obj[-1], tf.Tensor):
shape_list_obj[-1] = None
return tf.TensorShape(shape_list_obj)
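# Illustrative use: only the last dimension stays static, which is the
# relaxed shape invariant tf.while_loop needs for tensors whose sequence
# length grows every iteration:
#   get_shape_keep_last_dim(tf.zeros([2, 5, 8]))  # TensorShape([None, None, 8])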
def expand_to_same_rank(tensor, target):
"""Expands a given tensor to target's rank to be broadcastable.
Args:
    tensor: input tensor to expand. Shape: [b, d1, ..., da]
    target: target tensor. Shape: [b, d1, ..., da, ..., dn]

  Returns:
    Expanded tensor of shape [b, d1, ..., da, 1, ..., 1] with the same rank
    as target.

  Raises:
    ValueError: If the shape rank of tensor or target is None.
"""
if tensor.shape.rank is None:
raise ValueError("Expect rank for tensor shape, but got None.")
if target.shape.rank is None:
raise ValueError("Expect rank for target shape, but got None.")
with tf.name_scope("expand_rank"):
diff_rank = target.shape.rank - tensor.shape.rank
for _ in range(diff_rank):
tensor = tf.expand_dims(tensor, -1)
return tensor
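# Illustrative use: trailing singleton dims make the tensor broadcastable
# against the target:
#   log_probs = tf.zeros([4, 3])                  # [batch, beam]
#   target = tf.zeros([4, 3, 7, 16])
#   expand_to_same_rank(log_probs, target).shape  # [4, 3, 1, 1]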
class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
"""A base class for the API required for decoding (go/decoding-tf-nlp)."""
def __init__(self,
length_normalization_fn: Callable[[int, tf.DType], float],
dtype: tf.DType = tf.float32,
decoding_name: Optional[str] = None,
extra_cache_output: bool = False):
"""Initialize the Decoding Module.
Args:
      length_normalization_fn: A callable that computes the length
        normalization parameter; it takes a length and a dtype and returns a
        float.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
      extra_cache_output: If True, the cache produced after processing the
        initial token is kept in the loop state under INITIAL_OUTPUT_CACHE.
"""
self.length_normalization_fn = length_normalization_fn
self.dtype = tf.as_dtype(dtype)
    self.decoding_name = decoding_name
    self.extra_cache_output = extra_cache_output
def generate(self,
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor],
initial_log_probs: Optional[tf.Tensor] = None) -> Output:
"""Implements the decoding strategy (beam_search or sampling).
Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn. int tensor
with shape [batch_size, 1]
initial_cache: dictionary for caching model outputs from previous step.
      initial_log_probs: Optional initial log probabilities, used when
        decoding should continue from a given prefix sequence.
Returns:
      Tuple of tensors:
        finished_sequence: shape [batch, max_seq_length].
        finished_scores: shape [batch].
        first_cache: the cache after processing the initial token; may be
          None.
"""
batch_size = (
initial_ids.shape.as_list()[0]
if self.padded_decode else tf.shape(initial_ids)[0])
state, state_shapes = self._create_initial_state(initial_ids, initial_cache,
batch_size,
initial_log_probs)
def _generate_step(state):
topk_seq, topk_log_probs, topk_ids, new_cache = self._grow_alive_seq(
state, batch_size)
new_finished_flags = self._finished_flags(topk_ids, state)
alive_state = self._get_new_alive_state(topk_seq,
topk_log_probs,
new_finished_flags,
new_cache)
finished_state = self._get_new_finished_state(state,
topk_seq,
topk_log_probs,
new_finished_flags,
batch_size)
new_state = {
StateKeys.CUR_INDEX: state[StateKeys.CUR_INDEX] + 1
}
new_state.update(alive_state)
new_state.update(finished_state)
if self.extra_cache_output:
i = state[StateKeys.CUR_INDEX]
old_cache = state[StateKeys.INITIAL_OUTPUT_CACHE]
def update_with_cache(new_state, cache):
"""Updates new_state with cache."""
new_state.update({StateKeys.INITIAL_OUTPUT_CACHE: cache})
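        # Keep the cache produced by the very first step (CUR_INDEX == 0) as
        # the initial-output cache; later steps carry the stored one forward.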
tf.cond(
tf.equal(i, 0), lambda: update_with_cache(new_state, new_cache),
lambda: update_with_cache(new_state, old_cache))
return [new_state]
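
    # Run `_generate_step` until `_continue_search` signals completion;
    # gradients are stopped because decoding is inference-only.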
finished_state = tf.nest.map_structure(
tf.stop_gradient,
tf.while_loop(
self._continue_search,
_generate_step,
loop_vars=[state],
shape_invariants=[state_shapes],
parallel_iterations=1,
name=self.decoding_name))
final_state = self._process_finished_state(finished_state[0])
return final_state
@abc.abstractmethod
def _create_initial_state(
self,
initial_ids: tf.Tensor,
initial_cache: Dict[str, tf.Tensor],
batch_size: int,
initial_log_probs: Optional[tf.Tensor] = None) -> InitialState:
"""Return initial state dictionary and its shape invariants."""
pass
@abc.abstractmethod
def _grow_alive_seq(self,
state: Dict[str, Any],
batch_size: int) -> InternalState:
"""Grow alive sequences by one token.
Args:
state: A dictionary with the current loop state.
batch_size: The given batch size
Returns:
Tuple of
(Top sequences,
Scores of returned sequences,
New ids,
New alive cache)
"""
pass
@abc.abstractmethod
def _get_new_alive_state(
self,
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
"""Gather the sequences that are still alive.
Args:
      new_seq: New sequences generated by growing the current alive
        sequences; an int32 tensor.
      new_log_probs: Log probabilities of the new sequences; a float32
        tensor.
      new_finished_flags: A boolean tensor indicating which sequences have
        finished.
new_cache: Dict of cached values for each sequence.
Returns:
Dictionary with alive keys from StateKeys.
"""
pass
@abc.abstractmethod
def _get_new_finished_state(self,
state: Dict[str, Any],
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
batch_size: int) -> Dict[str, tf.Tensor]:
"""Combine new and old finished sequences.
Args:
state: A dictionary with the current loop state.
      new_seq: New sequences generated by growing the current alive
        sequences; an int32 tensor.
      new_log_probs: Log probabilities of the new sequences; a float32
        tensor.
      new_finished_flags: A boolean tensor indicating which sequences have
        finished.
batch_size: The given batch size.
Returns:
Dictionary with finished keys from StateKeys.
"""
pass
@abc.abstractmethod
def _process_finished_state(self, finished_state: Dict[str, Any]) -> Output:
"""Process the alive/finished state to return final sequences and scores."""
pass
@abc.abstractmethod
def _continue_search(self, state: Dict[str, Any]) -> tf.Tensor:
"""Returns a bool tensor if the decoding loop should continue."""
pass
@abc.abstractmethod
def _finished_flags(self,
topk_ids: tf.Tensor,
state: Dict[str, Any]) -> tf.Tensor:
"""Calculate the finished flags."""
pass
def inf(self):
"""Returns a value close to infinity, but is still finite in `dtype`.
This is useful to get a very large value that is still zero when multiplied
by zero. The floating-point "Inf" value is NaN when multiplied by zero.
Returns:
A very large value.
"""
if self.dtype == dtypes.float32 or self.dtype == dtypes.bfloat16:
return 1e7
elif self.dtype == dtypes.float16:
return dtypes.float16.max
else:
raise AssertionError("Invalid dtype: %s" % self.dtype)
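

# Usage sketch (illustrative only; `MyGreedyDecoder` is a hypothetical
# subclass, not part of this module). A concrete strategy implements the
# abstract hooks above and then drives the loop through `generate`:
#
#   class MyGreedyDecoder(DecodingModule):
#     ...  # implement _create_initial_state, _grow_alive_seq,
#          # _get_new_alive_state, _get_new_finished_state,
#          # _process_finished_state, _continue_search, _finished_flags
#
#   decoder = MyGreedyDecoder(
#       length_normalization_fn=lambda length, dtype: 1.0)
#   seqs, scores, first_cache = decoder.generate(initial_ids, initial_cache)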